Fix handling of print period in EvaluationMonitor (#6499)
Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
parent
9a194273cd
commit
8139849ab6
@ -621,7 +621,7 @@ class EvaluationMonitor(TrainingCallback):
|
||||
msg += self._fmt_metric(data, metric_name, score, stdv)
|
||||
msg += '\n'
|
||||
|
||||
if (epoch % self.period) != 0 or self.period == 1:
|
||||
if (epoch % self.period) == 0 or self.period == 1:
|
||||
rabit.tracker_print(msg)
|
||||
self._latest = None
|
||||
else:
|
||||
|
||||
@ -33,15 +33,18 @@ class TestCallbacks:
|
||||
verbose_eval=verbose_eval)
|
||||
output: str = out.getvalue().strip()
|
||||
|
||||
pos = 0
|
||||
msg = 'Train-error'
|
||||
for i in range(rounds // int(verbose_eval)):
|
||||
pos = output.find('Train-error', pos)
|
||||
assert pos != -1
|
||||
pos += len(msg)
|
||||
|
||||
assert output.find('Train-error', pos) == -1
|
||||
|
||||
if int(verbose_eval) == 1:
|
||||
# Should print each iteration info
|
||||
assert len(output.split('\n')) == rounds
|
||||
elif int(verbose_eval) > rounds:
|
||||
# Should print first and latest iteration info
|
||||
assert len(output.split('\n')) == 2
|
||||
else:
|
||||
# Should print info by each period additionaly to first and latest iteration
|
||||
num_periods = rounds // int(verbose_eval)
|
||||
# Extra information is required for latest iteration
|
||||
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
|
||||
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required)
|
||||
|
||||
def test_evaluation_monitor(self):
|
||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||
@ -57,8 +60,10 @@ class TestCallbacks:
|
||||
assert len(evals_result['Train']['error']) == rounds
|
||||
assert len(evals_result['Valid']['error']) == rounds
|
||||
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, 4)
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)
|
||||
|
||||
def test_early_stopping(self):
|
||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user