allow arbitrary cross validation fold indices (#3353)
* allow arbitrary cross validation fold indices - use training indices passed to `folds` parameter in `training.cv` - update doc string * add tests for arbitrary fold indices
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
594bcea83e
commit
18813a26ab
@@ -1,4 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
try:
|
||||
# python 2
|
||||
from StringIO import StringIO
|
||||
except ImportError:
|
||||
# python 3
|
||||
from io import StringIO
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
@@ -8,6 +16,21 @@ dpath = 'demo/data/'
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def captured_output():
|
||||
"""
|
||||
Reassign stdout temporarily in order to test printed statements
|
||||
Taken from: https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python
|
||||
"""
|
||||
new_out, new_err = StringIO(), StringIO()
|
||||
old_out, old_err = sys.stdout, sys.stderr
|
||||
try:
|
||||
sys.stdout, sys.stderr = new_out, new_err
|
||||
yield sys.stdout, sys.stderr
|
||||
finally:
|
||||
sys.stdout, sys.stderr = old_out, old_err
|
||||
|
||||
|
||||
class TestBasic(unittest.TestCase):
|
||||
|
||||
def test_basic(self):
|
||||
@@ -238,3 +261,41 @@ class TestBasic(unittest.TestCase):
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
|
||||
assert isinstance(cv, dict)
|
||||
assert len(cv) == (4)
|
||||
|
||||
def test_cv_explicit_fold_indices(self):
|
||||
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
|
||||
folds = [
|
||||
# Train Test
|
||||
([1, 3], [5, 8]),
|
||||
([7, 9], [23, 43]),
|
||||
]
|
||||
|
||||
# return np.ndarray
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, folds=folds, as_pandas=False)
|
||||
assert isinstance(cv, dict)
|
||||
assert len(cv) == (4)
|
||||
|
||||
def test_cv_explicit_fold_indices_labels(self):
|
||||
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'reg:linear'}
|
||||
N = 100
|
||||
F = 3
|
||||
dm = xgb.DMatrix(data=np.random.randn(N, F), label=np.arange(N))
|
||||
folds = [
|
||||
# Train Test
|
||||
([1, 3], [5, 8]),
|
||||
([7, 9], [23, 43, 11]),
|
||||
]
|
||||
|
||||
# Use callback to log the test labels in each fold
|
||||
def cb(cbackenv):
|
||||
print([fold.dtest.get_label() for fold in cbackenv.cvfolds])
|
||||
|
||||
# Run cross validation and capture standard out to test callback result
|
||||
with captured_output() as (out, err):
|
||||
xgb.cv(
|
||||
params, dm, num_boost_round=1, folds=folds, callbacks=[cb],
|
||||
as_pandas=False
|
||||
)
|
||||
output = out.getvalue().strip()
|
||||
assert output == '[array([5., 8.], dtype=float32), array([23., 43., 11.], dtype=float32)]'
|
||||
|
||||
Reference in New Issue
Block a user