Fix mknfold using new StratifiedKFold API (#1660)
This commit is contained in:
committed by
Tianqi Chen
parent
b56c6097d9
commit
63829d656c
@@ -251,7 +251,7 @@ def test_sklearn_plotting():
|
||||
def test_sklearn_nfolds_cv():
|
||||
tm._skip_if_no_sklearn()
|
||||
from sklearn.datasets import load_digits
|
||||
from sklearn.cross_validation import StratifiedKFold
|
||||
from sklearn.model_selection import StratifiedKFold
|
||||
|
||||
digits = load_digits(3)
|
||||
X = digits['data']
|
||||
@@ -269,10 +269,10 @@ def test_sklearn_nfolds_cv():
|
||||
|
||||
seed = 2016
|
||||
nfolds = 5
|
||||
skf = StratifiedKFold(y, n_folds=nfolds, shuffle=True, random_state=seed)
|
||||
skf = StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=seed)
|
||||
|
||||
cv1 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, seed=seed)
|
||||
cv2 = xgb.cv(params, dm, num_boost_round=10, folds=skf, seed=seed)
|
||||
cv2 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, folds=skf, seed=seed)
|
||||
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
|
||||
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
|
||||
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
|
||||
|
||||
Reference in New Issue
Block a user