Fix period in evaluation monitor. (#6441)

This commit is contained in:
Jiaming Yuan 2020-11-28 14:18:33 -05:00 committed by Hyunsu Cho
parent 8a0db293c5
commit a2c778e2d1
2 changed files with 24 additions and 18 deletions

View File

@ -621,7 +621,7 @@ class EvaluationMonitor(TrainingCallback):
msg += self._fmt_metric(data, metric_name, score, stdv)
msg += '\n'
if (epoch % self.period) != 0:
if (epoch % self.period) != 0 or self.period == 1:
rabit.tracker_print(msg)
self._latest = None
else:

View File

@ -22,6 +22,27 @@ class TestCallbacks:
cls.X_valid = X[split:, ...]
cls.y_valid = y[split:, ...]
def run_evaluation_monitor(self, D_train, D_valid, rounds, verbose_eval):
evals_result = {}
with tm.captured_output() as (out, err):
xgb.train({'objective': 'binary:logistic',
'eval_metric': 'error'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
num_boost_round=rounds,
evals_result=evals_result,
verbose_eval=verbose_eval)
output: str = out.getvalue().strip()
pos = 0
msg = 'Train-error'
for i in range(rounds // int(verbose_eval)):
pos = output.find('Train-error', pos)
assert pos != -1
pos += len(msg)
assert output.find('Train-error', pos) == -1
def test_evaluation_monitor(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
@ -36,23 +57,8 @@ class TestCallbacks:
assert len(evals_result['Train']['error']) == rounds
assert len(evals_result['Valid']['error']) == rounds
with tm.captured_output() as (out, err):
xgb.train({'objective': 'binary:logistic',
'eval_metric': 'error'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
num_boost_round=rounds,
evals_result=evals_result,
verbose_eval=2)
output: str = out.getvalue().strip()
pos = 0
msg = 'Train-error'
for i in range(rounds // 2):
pos = output.find('Train-error', pos)
assert pos != -1
pos += len(msg)
assert output.find('Train-error', pos) == -1
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
def test_early_stopping(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)