Fix cv verbose_eval (#7291)

This commit is contained in:
Jiaming Yuan 2021-10-08 12:28:38 +08:00 committed by GitHub
parent f7caac2563
commit 578de9f762
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 27 deletions

View File

@ -472,13 +472,15 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
if is_new_callback: if is_new_callback:
assert all(isinstance(c, callback.TrainingCallback) assert all(isinstance(c, callback.TrainingCallback)
for c in callbacks), "You can't mix new and old callback styles." for c in callbacks), "You can't mix new and old callback styles."
if isinstance(verbose_eval, bool) and verbose_eval: if verbose_eval:
verbose_eval = 1 if verbose_eval is True else verbose_eval verbose_eval = 1 if verbose_eval is True else verbose_eval
callbacks.append(callback.EvaluationMonitor(period=verbose_eval, callbacks.append(
show_stdv=show_stdv)) callback.EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv)
)
if early_stopping_rounds: if early_stopping_rounds:
callbacks.append(callback.EarlyStopping( callbacks.append(
rounds=early_stopping_rounds, maximize=maximize)) callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
)
callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True) callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True)
else: else:
callbacks = _configure_deprecated_callbacks( callbacks = _configure_deprecated_callbacks(

View File

@ -1,3 +1,4 @@
from typing import Union
import xgboost as xgb import xgboost as xgb
import pytest import pytest
import os import os
@ -22,29 +23,47 @@ class TestCallbacks:
cls.X_valid = X[split:, ...] cls.X_valid = X[split:, ...]
cls.y_valid = y[split:, ...] cls.y_valid = y[split:, ...]
def run_evaluation_monitor(self, D_train, D_valid, rounds, verbose_eval): def run_evaluation_monitor(
evals_result = {} self,
with tm.captured_output() as (out, err): D_train: xgb.DMatrix,
xgb.train({'objective': 'binary:logistic', D_valid: xgb.DMatrix,
'eval_metric': 'error'}, D_train, rounds: int,
evals=[(D_train, 'Train'), (D_valid, 'Valid')], verbose_eval: Union[bool, int]
num_boost_round=rounds, ):
evals_result=evals_result, def check_output(output: str) -> None:
verbose_eval=verbose_eval) if int(verbose_eval) == 1:
output: str = out.getvalue().strip() # Should print each iteration info
assert len(output.split('\n')) == rounds
elif int(verbose_eval) > rounds:
# Should print first and latest iteration info
assert len(output.split('\n')) == 2
else:
# Should print info by each period additionaly to first and latest
# iteration
num_periods = rounds // int(verbose_eval)
# Extra information is required for latest iteration
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
assert len(output.split('\n')) == (
1 + num_periods + int(is_extra_info_required)
)
if int(verbose_eval) == 1: evals_result: xgb.callback.TrainingCallback.EvalsLog = {}
# Should print each iteration info params = {'objective': 'binary:logistic', 'eval_metric': 'error'}
assert len(output.split('\n')) == rounds with tm.captured_output() as (out, err):
elif int(verbose_eval) > rounds: xgb.train(
# Should print first and latest iteration info params, D_train,
assert len(output.split('\n')) == 2 evals=[(D_train, 'Train'), (D_valid, 'Valid')],
else: num_boost_round=rounds,
# Should print info by each period additionaly to first and latest iteration evals_result=evals_result,
num_periods = rounds // int(verbose_eval) verbose_eval=verbose_eval,
# Extra information is required for latest iteration )
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1) output: str = out.getvalue().strip()
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required) check_output(output)
with tm.captured_output() as (out, err):
xgb.cv(params, D_train, num_boost_round=rounds, verbose_eval=verbose_eval)
output = out.getvalue().strip()
check_output(output)
def test_evaluation_monitor(self): def test_evaluation_monitor(self):
D_train = xgb.DMatrix(self.X_train, self.y_train) D_train = xgb.DMatrix(self.X_train, self.y_train)