Fix mknfold using new StratifiedKFold API (#1660)

This commit is contained in:
Yuan (Terry) Tang 2016-10-12 16:43:37 -05:00 committed by Tianqi Chen
parent b56c6097d9
commit 63829d656c
3 changed files with 8 additions and 10 deletions

View File

@ -46,7 +46,7 @@ except ImportError:
try:
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder # noqa
from sklearn.preprocessing import LabelEncoder
try:
from sklearn.model_selection import KFold, StratifiedKFold
except ImportError:

View File

@ -232,14 +232,12 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
randidx = np.random.permutation(dall.num_row())
kstep = int(len(randidx) / nfold)
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
elif folds is not None:
elif folds is not None and isinstance(folds, list):
idset = [x[1] for x in folds]
nfold = len(idset)
else:
idset = [x[1] for x in XGBStratifiedKFold(dall.get_label(),
n_folds=nfold,
shuffle=True,
random_state=seed)]
sfk = XGBStratifiedKFold(n_splits=nfold, shuffle=True, random_state=seed)
idset = [x[1] for x in sfk.split(X=dall.get_label(), y=dall.get_label())]
ret = []
for k in range(nfold):

View File

@ -251,7 +251,7 @@ def test_sklearn_plotting():
def test_sklearn_nfolds_cv():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
from sklearn.cross_validation import StratifiedKFold
from sklearn.model_selection import StratifiedKFold
digits = load_digits(3)
X = digits['data']
@ -269,10 +269,10 @@ def test_sklearn_nfolds_cv():
seed = 2016
nfolds = 5
skf = StratifiedKFold(y, n_folds=nfolds, shuffle=True, random_state=seed)
skf = StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=seed)
cv1 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, seed=seed)
cv2 = xgb.cv(params, dm, num_boost_round=10, folds=skf, seed=seed)
cv2 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, folds=skf, seed=seed)
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]