Fix mknfold using new StratifiedKFold API (#1660)

This commit is contained in:
Yuan (Terry) Tang
2016-10-12 16:43:37 -05:00
committed by Tianqi Chen
parent b56c6097d9
commit 63829d656c
3 changed files with 8 additions and 10 deletions

View File

@@ -251,7 +251,7 @@ def test_sklearn_plotting():
def test_sklearn_nfolds_cv():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
from sklearn.cross_validation import StratifiedKFold
from sklearn.model_selection import StratifiedKFold
digits = load_digits(3)
X = digits['data']
@@ -269,10 +269,10 @@ def test_sklearn_nfolds_cv():
seed = 2016
nfolds = 5
skf = StratifiedKFold(y, n_folds=nfolds, shuffle=True, random_state=seed)
skf = StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=seed)
cv1 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, seed=seed)
cv2 = xgb.cv(params, dm, num_boost_round=10, folds=skf, seed=seed)
cv2 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, folds=skf, seed=seed)
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]