From a2c778e2d12246cbc1da5b692a7cae23026693ce Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Sat, 28 Nov 2020 14:18:33 -0500
Subject: [PATCH] Fix period in evaluation monitor. (#6441)

---
 python-package/xgboost/callback.py |  2 +-
 tests/python/test_callback.py      | 40 +++++++++++++++++-------------
 2 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py
index b9583c381..e15bf699f 100644
--- a/python-package/xgboost/callback.py
+++ b/python-package/xgboost/callback.py
@@ -621,7 +621,7 @@ class EvaluationMonitor(TrainingCallback):
                     msg += self._fmt_metric(data, metric_name, score, stdv)
             msg += '\n'
 
-            if (epoch % self.period) != 0:
+            if (epoch % self.period) != 0 or self.period == 1:
                 rabit.tracker_print(msg)
                 self._latest = None
             else:
diff --git a/tests/python/test_callback.py b/tests/python/test_callback.py
index e1c5feee0..bdae94f87 100644
--- a/tests/python/test_callback.py
+++ b/tests/python/test_callback.py
@@ -22,6 +22,27 @@ class TestCallbacks:
         cls.X_valid = X[split:, ...]
         cls.y_valid = y[split:, ...]
 
+    def run_evaluation_monitor(self, D_train, D_valid, rounds, verbose_eval):
+        evals_result = {}
+        with tm.captured_output() as (out, err):
+            xgb.train({'objective': 'binary:logistic',
+                       'eval_metric': 'error'}, D_train,
+                      evals=[(D_train, 'Train'), (D_valid, 'Valid')],
+                      num_boost_round=rounds,
+                      evals_result=evals_result,
+                      verbose_eval=verbose_eval)
+        output: str = out.getvalue().strip()
+
+        pos = 0
+        msg = 'Train-error'
+        for i in range(rounds // int(verbose_eval)):
+            pos = output.find('Train-error', pos)
+            assert pos != -1
+            pos += len(msg)
+
+        assert output.find('Train-error', pos) == -1
+
+
     def test_evaluation_monitor(self):
         D_train = xgb.DMatrix(self.X_train, self.y_train)
         D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
@@ -36,23 +57,8 @@ class TestCallbacks:
         assert len(evals_result['Train']['error']) == rounds
         assert len(evals_result['Valid']['error']) == rounds
 
-        with tm.captured_output() as (out, err):
-            xgb.train({'objective': 'binary:logistic',
-                       'eval_metric': 'error'}, D_train,
-                      evals=[(D_train, 'Train'), (D_valid, 'Valid')],
-                      num_boost_round=rounds,
-                      evals_result=evals_result,
-                      verbose_eval=2)
-        output: str = out.getvalue().strip()
-
-        pos = 0
-        msg = 'Train-error'
-        for i in range(rounds // 2):
-            pos = output.find('Train-error', pos)
-            assert pos != -1
-            pos += len(msg)
-
-        assert output.find('Train-error', pos) == -1
+        self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
+        self.run_evaluation_monitor(D_train, D_valid, rounds, True)
 
     def test_early_stopping(self):
         D_train = xgb.DMatrix(self.X_train, self.y_train)