Fix mknfold using new StratifiedKFold API (#1660)

This commit is contained in:
Yuan (Terry) Tang
2016-10-12 16:43:37 -05:00
committed by Tianqi Chen
parent b56c6097d9
commit 63829d656c
3 changed files with 8 additions and 10 deletions

View File

@@ -232,14 +232,12 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
randidx = np.random.permutation(dall.num_row())
kstep = int(len(randidx) / nfold)
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
elif folds is not None:
elif folds is not None and isinstance(folds, list):
idset = [x[1] for x in folds]
nfold = len(idset)
else:
idset = [x[1] for x in XGBStratifiedKFold(dall.get_label(),
n_folds=nfold,
shuffle=True,
random_state=seed)]
sfk = XGBStratifiedKFold(n_splits=nfold, shuffle=True, random_state=seed)
idset = [x[1] for x in sfk.split(X=dall.get_label(), y=dall.get_label())]
ret = []
for k in range(nfold):