Fix handling of print period in EvaluationMonitor (#6499)

Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
ShvetsKS 2020-12-15 14:20:19 +03:00 committed by GitHub
parent 9a194273cd
commit 8139849ab6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 11 deletions

View File

@ -621,7 +621,7 @@ class EvaluationMonitor(TrainingCallback):
msg += self._fmt_metric(data, metric_name, score, stdv)
msg += '\n'
if (epoch % self.period) != 0 or self.period == 1:
if (epoch % self.period) == 0 or self.period == 1:
rabit.tracker_print(msg)
self._latest = None
else:

View File

@ -33,15 +33,18 @@ class TestCallbacks:
verbose_eval=verbose_eval)
output: str = out.getvalue().strip()
pos = 0
msg = 'Train-error'
for i in range(rounds // int(verbose_eval)):
pos = output.find('Train-error', pos)
assert pos != -1
pos += len(msg)
assert output.find('Train-error', pos) == -1
if int(verbose_eval) == 1:
# Should print each iteration info
assert len(output.split('\n')) == rounds
elif int(verbose_eval) > rounds:
# Should print first and latest iteration info
assert len(output.split('\n')) == 2
else:
# Should print info by each period additionaly to first and latest iteration
num_periods = rounds // int(verbose_eval)
# Extra information is required for latest iteration
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required)
def test_evaluation_monitor(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
@ -57,8 +60,10 @@ class TestCallbacks:
assert len(evals_result['Train']['error']) == rounds
assert len(evals_result['Valid']['error']) == rounds
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
self.run_evaluation_monitor(D_train, D_valid, rounds, 4)
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)
def test_early_stopping(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)