Fix handling of print period in EvaluationMonitor (#6499)

Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
ShvetsKS 2020-12-15 14:20:19 +03:00 committed by GitHub
parent 9a194273cd
commit 8139849ab6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 11 deletions

View File

@ -621,7 +621,7 @@ class EvaluationMonitor(TrainingCallback):
msg += self._fmt_metric(data, metric_name, score, stdv) msg += self._fmt_metric(data, metric_name, score, stdv)
msg += '\n' msg += '\n'
if (epoch % self.period) != 0 or self.period == 1: if (epoch % self.period) == 0 or self.period == 1:
rabit.tracker_print(msg) rabit.tracker_print(msg)
self._latest = None self._latest = None
else: else:

View File

@ -33,15 +33,18 @@ class TestCallbacks:
verbose_eval=verbose_eval) verbose_eval=verbose_eval)
output: str = out.getvalue().strip() output: str = out.getvalue().strip()
pos = 0 if int(verbose_eval) == 1:
msg = 'Train-error' # Should print each iteration info
for i in range(rounds // int(verbose_eval)): assert len(output.split('\n')) == rounds
pos = output.find('Train-error', pos) elif int(verbose_eval) > rounds:
assert pos != -1 # Should print first and latest iteration info
pos += len(msg) assert len(output.split('\n')) == 2
else:
assert output.find('Train-error', pos) == -1 # Should print info by each period additionally to first and latest iteration
num_periods = rounds // int(verbose_eval)
# Extra information is required for latest iteration
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required)
def test_evaluation_monitor(self): def test_evaluation_monitor(self):
D_train = xgb.DMatrix(self.X_train, self.y_train) D_train = xgb.DMatrix(self.X_train, self.y_train)
@ -57,8 +60,10 @@ class TestCallbacks:
assert len(evals_result['Train']['error']) == rounds assert len(evals_result['Train']['error']) == rounds
assert len(evals_result['Valid']['error']) == rounds assert len(evals_result['Valid']['error']) == rounds
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
self.run_evaluation_monitor(D_train, D_valid, rounds, True) self.run_evaluation_monitor(D_train, D_valid, rounds, True)
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
self.run_evaluation_monitor(D_train, D_valid, rounds, 4)
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)
def test_early_stopping(self): def test_early_stopping(self):
D_train = xgb.DMatrix(self.X_train, self.y_train) D_train = xgb.DMatrix(self.X_train, self.y_train)