allow arbitrary cross validation fold indices (#3353)

* allow arbitrary cross validation fold indices

 - use training indices passed to `folds` parameter in `training.cv`
 - update doc string

* add tests for arbitrary fold indices
This commit is contained in:
Oliver Laslett
2018-06-30 20:23:49 +01:00
committed by Philip Hyunsu Cho
parent 594bcea83e
commit 18813a26ab
2 changed files with 92 additions and 9 deletions

View File

@@ -1,4 +1,12 @@
# -*- coding: utf-8 -*-
import sys
from contextlib import contextmanager
try:
# python 2
from StringIO import StringIO
except ImportError:
# python 3
from io import StringIO
import numpy as np
import xgboost as xgb
import unittest
@@ -8,6 +16,21 @@ dpath = 'demo/data/'
rng = np.random.RandomState(1994)
@contextmanager
def captured_output():
"""
Reassign stdout temporarily in order to test printed statements
Taken from: https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python
"""
new_out, new_err = StringIO(), StringIO()
old_out, old_err = sys.stdout, sys.stderr
try:
sys.stdout, sys.stderr = new_out, new_err
yield sys.stdout, sys.stderr
finally:
sys.stdout, sys.stderr = old_out, old_err
class TestBasic(unittest.TestCase):
def test_basic(self):
@@ -238,3 +261,41 @@ class TestBasic(unittest.TestCase):
cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
assert isinstance(cv, dict)
assert len(cv) == (4)
def test_cv_explicit_fold_indices(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
folds = [
# Train Test
([1, 3], [5, 8]),
([7, 9], [23, 43]),
]
# return np.ndarray
cv = xgb.cv(params, dm, num_boost_round=10, folds=folds, as_pandas=False)
assert isinstance(cv, dict)
assert len(cv) == (4)
def test_cv_explicit_fold_indices_labels(self):
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'reg:linear'}
N = 100
F = 3
dm = xgb.DMatrix(data=np.random.randn(N, F), label=np.arange(N))
folds = [
# Train Test
([1, 3], [5, 8]),
([7, 9], [23, 43, 11]),
]
# Use callback to log the test labels in each fold
def cb(cbackenv):
print([fold.dtest.get_label() for fold in cbackenv.cvfolds])
# Run cross validation and capture standard out to test callback result
with captured_output() as (out, err):
xgb.cv(
params, dm, num_boost_round=1, folds=folds, callbacks=[cb],
as_pandas=False
)
output = out.getvalue().strip()
assert output == '[array([5., 8.], dtype=float32), array([23., 43., 11.], dtype=float32)]'