allow arbitrary cross validation fold indices (#3353)
* allow arbitrary cross validation fold indices
  - use training indices passed to `folds` parameter in `training.cv`
  - update doc string
* add tests for arbitrary fold indices
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
594bcea83e
commit
18813a26ab
@@ -231,22 +231,39 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
||||
np.random.seed(seed)
|
||||
|
||||
if stratified is False and folds is None:
|
||||
# Do standard k-fold cross validation
|
||||
if shuffle is True:
|
||||
idx = np.random.permutation(dall.num_row())
|
||||
else:
|
||||
idx = np.arange(dall.num_row())
|
||||
idset = np.array_split(idx, nfold)
|
||||
elif folds is not None and isinstance(folds, list):
|
||||
idset = [x[1] for x in folds]
|
||||
nfold = len(idset)
|
||||
out_idset = np.array_split(idx, nfold)
|
||||
in_idset = [
|
||||
np.concatenate([out_idset[i] for i in range(nfold) if k != i])
|
||||
for k in range(nfold)
|
||||
]
|
||||
elif folds is not None:
|
||||
# Use user specified custom split using indices
|
||||
try:
|
||||
in_idset = [x[0] for x in folds]
|
||||
out_idset = [x[1] for x in folds]
|
||||
except TypeError:
|
||||
# Custom stratification using Sklearn KFoldSplit object
|
||||
splits = list(folds.split(X=dall.get_label(), y=dall.get_label()))
|
||||
in_idset = [x[0] for x in splits]
|
||||
out_idset = [x[1] for x in splits]
|
||||
nfold = len(out_idset)
|
||||
else:
|
||||
# Do standard stratified shuffle k-fold split
|
||||
sfk = XGBStratifiedKFold(n_splits=nfold, shuffle=True, random_state=seed)
|
||||
idset = [x[1] for x in sfk.split(X=dall.get_label(), y=dall.get_label())]
|
||||
splits = list(sfk.split(X=dall.get_label(), y=dall.get_label()))
|
||||
in_idset = [x[0] for x in splits]
|
||||
out_idset = [x[1] for x in splits]
|
||||
nfold = len(out_idset)
|
||||
|
||||
ret = []
|
||||
for k in range(nfold):
|
||||
dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i]))
|
||||
dtest = dall.slice(idset[k])
|
||||
dtrain = dall.slice(in_idset[k])
|
||||
dtest = dall.slice(out_idset[k])
|
||||
# run preprocessing on the data set if needed
|
||||
if fpreproc is not None:
|
||||
dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
|
||||
@@ -308,8 +325,13 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
||||
Number of folds in CV.
|
||||
stratified : bool
|
||||
Perform stratified sampling.
|
||||
folds : a KFold or StratifiedKFold instance
|
||||
Sklearn KFolds or StratifiedKFolds.
|
||||
folds : a KFold or StratifiedKFold instance or list of fold indices
|
||||
Sklearn KFolds or StratifiedKFolds object.
|
||||
Alternatively may explicitly pass sample indices for each fold.
|
||||
For `n` folds, `folds` should be a length `n` list of tuples.
|
||||
Each tuple is `(in,out)` where `in` is a list of indices to be used
|
||||
as the training samples for the `n`th fold and `out` is a list of
|
||||
indices to be used as the testing samples for the `n`th fold.
|
||||
metrics : string or list of strings
|
||||
Evaluation metrics to be watched in CV.
|
||||
obj : function
|
||||
|
||||
Reference in New Issue
Block a user