Add option to shuffle data in mknfold (#1459)
* Add option to shuffle data in mknfold
* Remove the possibility to run as a standalone test
* Split the function definition across two lines for lint
This commit is contained in:
parent
b49b339183
commit
fb0fc0c580
@ -222,7 +222,8 @@ class CVPack(object):
|
|||||||
return self.bst.eval_set(self.watchlist, iteration, feval)
|
return self.bst.eval_set(self.watchlist, iteration, feval)
|
||||||
|
|
||||||
|
|
||||||
def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False, folds=None):
|
def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
||||||
|
folds=None, shuffle=True):
|
||||||
"""
|
"""
|
||||||
Make an n-fold list of CVPack from random indices.
|
Make an n-fold list of CVPack from random indices.
|
||||||
"""
|
"""
|
||||||
@ -230,9 +231,12 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
|||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
||||||
if stratified is False and folds is None:
|
if stratified is False and folds is None:
|
||||||
randidx = np.random.permutation(dall.num_row())
|
if shuffle is True:
|
||||||
kstep = int(len(randidx) / nfold)
|
idx = np.random.permutation(dall.num_row())
|
||||||
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
|
else:
|
||||||
|
idx = np.arange(dall.num_row())
|
||||||
|
kstep = int(len(idx) / nfold)
|
||||||
|
idset = [idx[(i * kstep): min(len(idx), (i + 1) * kstep)] for i in range(nfold)]
|
||||||
elif folds is not None and isinstance(folds, list):
|
elif folds is not None and isinstance(folds, list):
|
||||||
idset = [x[1] for x in folds]
|
idset = [x[1] for x in folds]
|
||||||
nfold = len(idset)
|
nfold = len(idset)
|
||||||
@ -289,7 +293,7 @@ def aggcv(rlist):
|
|||||||
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
|
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
|
||||||
metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
|
metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
|
||||||
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
|
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
|
||||||
seed=0, callbacks=None):
|
seed=0, callbacks=None, shuffle=True):
|
||||||
# pylint: disable = invalid-name
|
# pylint: disable = invalid-name
|
||||||
"""Cross-validation with given parameters.
|
"""Cross-validation with given parameters.
|
||||||
|
|
||||||
@ -339,6 +343,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
|||||||
List of callback functions that are applied at end of each iteration.
|
List of callback functions that are applied at end of each iteration.
|
||||||
It is possible to use predefined callbacks by using xgb.callback module.
|
It is possible to use predefined callbacks by using xgb.callback module.
|
||||||
Example: [xgb.callback.reset_learning_rate(custom_rates)]
|
Example: [xgb.callback.reset_learning_rate(custom_rates)]
|
||||||
|
shuffle : bool
|
||||||
|
Shuffle data before creating folds.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -367,7 +373,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
|||||||
params.pop("eval_metric", None)
|
params.pop("eval_metric", None)
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds)
|
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc,
|
||||||
|
stratified, folds, shuffle)
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
callbacks = [] if callbacks is None else callbacks
|
callbacks = [] if callbacks is None else callbacks
|
||||||
|
|||||||
@ -241,3 +241,12 @@ class TestBasic(unittest.TestCase):
|
|||||||
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
|
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
|
||||||
assert isinstance(cv, dict)
|
assert isinstance(cv, dict)
|
||||||
assert len(cv) == (4)
|
assert len(cv) == (4)
|
||||||
|
|
||||||
|
def test_cv_no_shuffle(self):
|
||||||
|
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
|
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
|
||||||
|
|
||||||
|
# as_pandas=False: the result is asserted to be a dict below
|
||||||
|
cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
|
||||||
|
assert isinstance(cv, dict)
|
||||||
|
assert len(cv) == (4)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user