allow arbitrary cross validation fold indices (#3353)
* allow arbitrary cross validation fold indices - use training indices passed to `folds` parameter in `training.cv` - update doc string * add tests for arbitrary fold indices
This commit is contained in:
parent
594bcea83e
commit
18813a26ab
@ -231,22 +231,39 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
||||
np.random.seed(seed)
|
||||
|
||||
if stratified is False and folds is None:
|
||||
# Do standard k-fold cross validation
|
||||
if shuffle is True:
|
||||
idx = np.random.permutation(dall.num_row())
|
||||
else:
|
||||
idx = np.arange(dall.num_row())
|
||||
idset = np.array_split(idx, nfold)
|
||||
elif folds is not None and isinstance(folds, list):
|
||||
idset = [x[1] for x in folds]
|
||||
nfold = len(idset)
|
||||
out_idset = np.array_split(idx, nfold)
|
||||
in_idset = [
|
||||
np.concatenate([out_idset[i] for i in range(nfold) if k != i])
|
||||
for k in range(nfold)
|
||||
]
|
||||
elif folds is not None:
|
||||
# Use user specified custom split using indices
|
||||
try:
|
||||
in_idset = [x[0] for x in folds]
|
||||
out_idset = [x[1] for x in folds]
|
||||
except TypeError:
|
||||
# Custom stratification using Sklearn KFoldSplit object
|
||||
splits = list(folds.split(X=dall.get_label(), y=dall.get_label()))
|
||||
in_idset = [x[0] for x in splits]
|
||||
out_idset = [x[1] for x in splits]
|
||||
nfold = len(out_idset)
|
||||
else:
|
||||
# Do standard stratefied shuffle k-fold split
|
||||
sfk = XGBStratifiedKFold(n_splits=nfold, shuffle=True, random_state=seed)
|
||||
idset = [x[1] for x in sfk.split(X=dall.get_label(), y=dall.get_label())]
|
||||
splits = list(sfk.split(X=dall.get_label(), y=dall.get_label()))
|
||||
in_idset = [x[0] for x in splits]
|
||||
out_idset = [x[1] for x in splits]
|
||||
nfold = len(out_idset)
|
||||
|
||||
ret = []
|
||||
for k in range(nfold):
|
||||
dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i]))
|
||||
dtest = dall.slice(idset[k])
|
||||
dtrain = dall.slice(in_idset[k])
|
||||
dtest = dall.slice(out_idset[k])
|
||||
# run preprocessing on the data set if needed
|
||||
if fpreproc is not None:
|
||||
dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
|
||||
@ -308,8 +325,13 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
||||
Number of folds in CV.
|
||||
stratified : bool
|
||||
Perform stratified sampling.
|
||||
folds : a KFold or StratifiedKFold instance
|
||||
Sklearn KFolds or StratifiedKFolds.
|
||||
folds : a KFold or StratifiedKFold instance or list of fold indices
|
||||
Sklearn KFolds or StratifiedKFolds object.
|
||||
Alternatively may explicitly pass sample indices for each fold.
|
||||
For `n` folds, `folds` should be a length `n` list of tuples.
|
||||
Each tuple is `(in,out)` where `in` is a list of indices to be used
|
||||
as the training samples for the `n`th fold and `out` is a list of
|
||||
indices to be used as the testing samples for the `n`th fold.
|
||||
metrics : string or list of strings
|
||||
Evaluation metrics to be watched in CV.
|
||||
obj : function
|
||||
|
||||
@ -1,4 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
try:
|
||||
# python 2
|
||||
from StringIO import StringIO
|
||||
except ImportError:
|
||||
# python 3
|
||||
from io import StringIO
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
@ -8,6 +16,21 @@ dpath = 'demo/data/'
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def captured_output():
|
||||
"""
|
||||
Reassign stdout temporarily in order to test printed statements
|
||||
Taken from: https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python
|
||||
"""
|
||||
new_out, new_err = StringIO(), StringIO()
|
||||
old_out, old_err = sys.stdout, sys.stderr
|
||||
try:
|
||||
sys.stdout, sys.stderr = new_out, new_err
|
||||
yield sys.stdout, sys.stderr
|
||||
finally:
|
||||
sys.stdout, sys.stderr = old_out, old_err
|
||||
|
||||
|
||||
class TestBasic(unittest.TestCase):
|
||||
|
||||
def test_basic(self):
|
||||
@ -238,3 +261,41 @@ class TestBasic(unittest.TestCase):
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
|
||||
assert isinstance(cv, dict)
|
||||
assert len(cv) == (4)
|
||||
|
||||
def test_cv_explicit_fold_indices(self):
|
||||
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
|
||||
folds = [
|
||||
# Train Test
|
||||
([1, 3], [5, 8]),
|
||||
([7, 9], [23, 43]),
|
||||
]
|
||||
|
||||
# return np.ndarray
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, folds=folds, as_pandas=False)
|
||||
assert isinstance(cv, dict)
|
||||
assert len(cv) == (4)
|
||||
|
||||
def test_cv_explicit_fold_indices_labels(self):
|
||||
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'reg:linear'}
|
||||
N = 100
|
||||
F = 3
|
||||
dm = xgb.DMatrix(data=np.random.randn(N, F), label=np.arange(N))
|
||||
folds = [
|
||||
# Train Test
|
||||
([1, 3], [5, 8]),
|
||||
([7, 9], [23, 43, 11]),
|
||||
]
|
||||
|
||||
# Use callback to log the test labels in each fold
|
||||
def cb(cbackenv):
|
||||
print([fold.dtest.get_label() for fold in cbackenv.cvfolds])
|
||||
|
||||
# Run cross validation and capture standard out to test callback result
|
||||
with captured_output() as (out, err):
|
||||
xgb.cv(
|
||||
params, dm, num_boost_round=1, folds=folds, callbacks=[cb],
|
||||
as_pandas=False
|
||||
)
|
||||
output = out.getvalue().strip()
|
||||
assert output == '[array([5., 8.], dtype=float32), array([23., 43., 11.], dtype=float32)]'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user