Fix handling of print period in EvaluationMonitor (#6499)
Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
parent
9a194273cd
commit
8139849ab6
@ -621,7 +621,7 @@ class EvaluationMonitor(TrainingCallback):
|
|||||||
msg += self._fmt_metric(data, metric_name, score, stdv)
|
msg += self._fmt_metric(data, metric_name, score, stdv)
|
||||||
msg += '\n'
|
msg += '\n'
|
||||||
|
|
||||||
if (epoch % self.period) != 0 or self.period == 1:
|
if (epoch % self.period) == 0 or self.period == 1:
|
||||||
rabit.tracker_print(msg)
|
rabit.tracker_print(msg)
|
||||||
self._latest = None
|
self._latest = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -33,15 +33,18 @@ class TestCallbacks:
|
|||||||
verbose_eval=verbose_eval)
|
verbose_eval=verbose_eval)
|
||||||
output: str = out.getvalue().strip()
|
output: str = out.getvalue().strip()
|
||||||
|
|
||||||
pos = 0
|
if int(verbose_eval) == 1:
|
||||||
msg = 'Train-error'
|
# Should print each iteration info
|
||||||
for i in range(rounds // int(verbose_eval)):
|
assert len(output.split('\n')) == rounds
|
||||||
pos = output.find('Train-error', pos)
|
elif int(verbose_eval) > rounds:
|
||||||
assert pos != -1
|
# Should print first and latest iteration info
|
||||||
pos += len(msg)
|
assert len(output.split('\n')) == 2
|
||||||
|
else:
|
||||||
assert output.find('Train-error', pos) == -1
|
# Should print info by each period additionaly to first and latest iteration
|
||||||
|
num_periods = rounds // int(verbose_eval)
|
||||||
|
# Extra information is required for latest iteration
|
||||||
|
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
|
||||||
|
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required)
|
||||||
|
|
||||||
def test_evaluation_monitor(self):
|
def test_evaluation_monitor(self):
|
||||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||||
@ -57,8 +60,10 @@ class TestCallbacks:
|
|||||||
assert len(evals_result['Train']['error']) == rounds
|
assert len(evals_result['Train']['error']) == rounds
|
||||||
assert len(evals_result['Valid']['error']) == rounds
|
assert len(evals_result['Valid']['error']) == rounds
|
||||||
|
|
||||||
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
|
|
||||||
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
|
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
|
||||||
|
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
|
||||||
|
self.run_evaluation_monitor(D_train, D_valid, rounds, 4)
|
||||||
|
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)
|
||||||
|
|
||||||
def test_early_stopping(self):
|
def test_early_stopping(self):
|
||||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user