Fix mknfold using new StratifiedKFold API (#1660)
This commit is contained in:
committed by
Tianqi Chen
parent
b56c6097d9
commit
63829d656c
@@ -232,14 +232,12 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
||||
randidx = np.random.permutation(dall.num_row())
|
||||
kstep = int(len(randidx) / nfold)
|
||||
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
|
||||
elif folds is not None:
|
||||
elif folds is not None and isinstance(folds, list):
|
||||
idset = [x[1] for x in folds]
|
||||
nfold = len(idset)
|
||||
else:
|
||||
idset = [x[1] for x in XGBStratifiedKFold(dall.get_label(),
|
||||
n_folds=nfold,
|
||||
shuffle=True,
|
||||
random_state=seed)]
|
||||
sfk = XGBStratifiedKFold(n_splits=nfold, shuffle=True, random_state=seed)
|
||||
idset = [x[1] for x in sfk.split(X=dall.get_label(), y=dall.get_label())]
|
||||
|
||||
ret = []
|
||||
for k in range(nfold):
|
||||
|
||||
Reference in New Issue
Block a user