option to shuffle data in mknfolds (#1459)

* option to shuffle data in mknfolds

* removed possibility to run as stand alone test

* split function def in 2 lines for lint

* option to shuffle data in mknfolds

* removed possibility to run as stand alone test

* split function def in 2 lines for lint
This commit is contained in:
jokari69 2016-12-22 17:53:30 -06:00 committed by Yuan (Terry) Tang
parent b49b339183
commit fb0fc0c580
2 changed files with 22 additions and 6 deletions

View File

@ -222,7 +222,8 @@ class CVPack(object):
return self.bst.eval_set(self.watchlist, iteration, feval) return self.bst.eval_set(self.watchlist, iteration, feval)
def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False, folds=None): def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
folds=None, shuffle=True):
""" """
Make an n-fold list of CVPack from random indices. Make an n-fold list of CVPack from random indices.
""" """
@ -230,9 +231,12 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
np.random.seed(seed) np.random.seed(seed)
if stratified is False and folds is None: if stratified is False and folds is None:
randidx = np.random.permutation(dall.num_row()) if shuffle is True:
kstep = int(len(randidx) / nfold) idx = np.random.permutation(dall.num_row())
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)] else:
idx = np.arange(dall.num_row())
kstep = int(len(idx) / nfold)
idset = [idx[(i * kstep): min(len(idx), (i + 1) * kstep)] for i in range(nfold)]
elif folds is not None and isinstance(folds, list): elif folds is not None and isinstance(folds, list):
idset = [x[1] for x in folds] idset = [x[1] for x in folds]
nfold = len(idset) nfold = len(idset)
@ -289,7 +293,7 @@ def aggcv(rlist):
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None, def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None, metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
seed=0, callbacks=None): seed=0, callbacks=None, shuffle=True):
# pylint: disable = invalid-name # pylint: disable = invalid-name
"""Cross-validation with given parameters. """Cross-validation with given parameters.
@ -339,6 +343,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
List of callback functions that are applied at end of each iteration. List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using xgb.callback module. It is possible to use predefined callbacks by using xgb.callback module.
Example: [xgb.callback.reset_learning_rate(custom_rates)] Example: [xgb.callback.reset_learning_rate(custom_rates)]
shuffle : bool
Shuffle data before creating folds.
Returns Returns
------- -------
@ -367,7 +373,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
params.pop("eval_metric", None) params.pop("eval_metric", None)
results = {} results = {}
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds) cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc,
stratified, folds, shuffle)
# setup callbacks # setup callbacks
callbacks = [] if callbacks is None else callbacks callbacks = [] if callbacks is None else callbacks

View File

@ -241,3 +241,12 @@ class TestBasic(unittest.TestCase):
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False) cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
assert isinstance(cv, dict) assert isinstance(cv, dict)
assert len(cv) == (4) assert len(cv) == (4)
def test_cv_no_shuffle(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
# return np.ndarray
cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
assert isinstance(cv, dict)
assert len(cv) == (4)