parent
cdbfd21d31
commit
c4aff733bb
@ -472,13 +472,15 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
|||||||
if is_new_callback:
|
if is_new_callback:
|
||||||
assert all(isinstance(c, callback.TrainingCallback)
|
assert all(isinstance(c, callback.TrainingCallback)
|
||||||
for c in callbacks), "You can't mix new and old callback styles."
|
for c in callbacks), "You can't mix new and old callback styles."
|
||||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
if verbose_eval:
|
||||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||||
callbacks.append(callback.EvaluationMonitor(period=verbose_eval,
|
callbacks.append(
|
||||||
show_stdv=show_stdv))
|
callback.EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv)
|
||||||
|
)
|
||||||
if early_stopping_rounds:
|
if early_stopping_rounds:
|
||||||
callbacks.append(callback.EarlyStopping(
|
callbacks.append(
|
||||||
rounds=early_stopping_rounds, maximize=maximize))
|
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
|
||||||
|
)
|
||||||
callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True)
|
callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True)
|
||||||
else:
|
else:
|
||||||
callbacks = _configure_deprecated_callbacks(
|
callbacks = _configure_deprecated_callbacks(
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from typing import Union
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import pytest
|
import pytest
|
||||||
import os
|
import os
|
||||||
@ -22,29 +23,47 @@ class TestCallbacks:
|
|||||||
cls.X_valid = X[split:, ...]
|
cls.X_valid = X[split:, ...]
|
||||||
cls.y_valid = y[split:, ...]
|
cls.y_valid = y[split:, ...]
|
||||||
|
|
||||||
def run_evaluation_monitor(self, D_train, D_valid, rounds, verbose_eval):
|
def run_evaluation_monitor(
|
||||||
evals_result = {}
|
self,
|
||||||
with tm.captured_output() as (out, err):
|
D_train: xgb.DMatrix,
|
||||||
xgb.train({'objective': 'binary:logistic',
|
D_valid: xgb.DMatrix,
|
||||||
'eval_metric': 'error'}, D_train,
|
rounds: int,
|
||||||
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
verbose_eval: Union[bool, int]
|
||||||
num_boost_round=rounds,
|
):
|
||||||
evals_result=evals_result,
|
def check_output(output: str) -> None:
|
||||||
verbose_eval=verbose_eval)
|
if int(verbose_eval) == 1:
|
||||||
output: str = out.getvalue().strip()
|
# Should print each iteration info
|
||||||
|
assert len(output.split('\n')) == rounds
|
||||||
|
elif int(verbose_eval) > rounds:
|
||||||
|
# Should print first and latest iteration info
|
||||||
|
assert len(output.split('\n')) == 2
|
||||||
|
else:
|
||||||
|
# Should print info by each period additionaly to first and latest
|
||||||
|
# iteration
|
||||||
|
num_periods = rounds // int(verbose_eval)
|
||||||
|
# Extra information is required for latest iteration
|
||||||
|
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
|
||||||
|
assert len(output.split('\n')) == (
|
||||||
|
1 + num_periods + int(is_extra_info_required)
|
||||||
|
)
|
||||||
|
|
||||||
if int(verbose_eval) == 1:
|
evals_result: xgb.callback.TrainingCallback.EvalsLog = {}
|
||||||
# Should print each iteration info
|
params = {'objective': 'binary:logistic', 'eval_metric': 'error'}
|
||||||
assert len(output.split('\n')) == rounds
|
with tm.captured_output() as (out, err):
|
||||||
elif int(verbose_eval) > rounds:
|
xgb.train(
|
||||||
# Should print first and latest iteration info
|
params, D_train,
|
||||||
assert len(output.split('\n')) == 2
|
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
||||||
else:
|
num_boost_round=rounds,
|
||||||
# Should print info by each period additionaly to first and latest iteration
|
evals_result=evals_result,
|
||||||
num_periods = rounds // int(verbose_eval)
|
verbose_eval=verbose_eval,
|
||||||
# Extra information is required for latest iteration
|
)
|
||||||
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
|
output: str = out.getvalue().strip()
|
||||||
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required)
|
check_output(output)
|
||||||
|
|
||||||
|
with tm.captured_output() as (out, err):
|
||||||
|
xgb.cv(params, D_train, num_boost_round=rounds, verbose_eval=verbose_eval)
|
||||||
|
output = out.getvalue().strip()
|
||||||
|
check_output(output)
|
||||||
|
|
||||||
def test_evaluation_monitor(self):
|
def test_evaluation_monitor(self):
|
||||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user