Fix period in evaluation monitor. (#6441)
This commit is contained in:
parent
8a0db293c5
commit
a2c778e2d1
@ -621,7 +621,7 @@ class EvaluationMonitor(TrainingCallback):
|
||||
msg += self._fmt_metric(data, metric_name, score, stdv)
|
||||
msg += '\n'
|
||||
|
||||
if (epoch % self.period) != 0:
|
||||
if (epoch % self.period) != 0 or self.period == 1:
|
||||
rabit.tracker_print(msg)
|
||||
self._latest = None
|
||||
else:
|
||||
|
||||
@ -22,6 +22,27 @@ class TestCallbacks:
|
||||
cls.X_valid = X[split:, ...]
|
||||
cls.y_valid = y[split:, ...]
|
||||
|
||||
def run_evaluation_monitor(self, D_train, D_valid, rounds, verbose_eval):
|
||||
evals_result = {}
|
||||
with tm.captured_output() as (out, err):
|
||||
xgb.train({'objective': 'binary:logistic',
|
||||
'eval_metric': 'error'}, D_train,
|
||||
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
||||
num_boost_round=rounds,
|
||||
evals_result=evals_result,
|
||||
verbose_eval=verbose_eval)
|
||||
output: str = out.getvalue().strip()
|
||||
|
||||
pos = 0
|
||||
msg = 'Train-error'
|
||||
for i in range(rounds // int(verbose_eval)):
|
||||
pos = output.find('Train-error', pos)
|
||||
assert pos != -1
|
||||
pos += len(msg)
|
||||
|
||||
assert output.find('Train-error', pos) == -1
|
||||
|
||||
|
||||
def test_evaluation_monitor(self):
|
||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
|
||||
@ -36,23 +57,8 @@ class TestCallbacks:
|
||||
assert len(evals_result['Train']['error']) == rounds
|
||||
assert len(evals_result['Valid']['error']) == rounds
|
||||
|
||||
with tm.captured_output() as (out, err):
|
||||
xgb.train({'objective': 'binary:logistic',
|
||||
'eval_metric': 'error'}, D_train,
|
||||
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
||||
num_boost_round=rounds,
|
||||
evals_result=evals_result,
|
||||
verbose_eval=2)
|
||||
output: str = out.getvalue().strip()
|
||||
|
||||
pos = 0
|
||||
msg = 'Train-error'
|
||||
for i in range(rounds // 2):
|
||||
pos = output.find('Train-error', pos)
|
||||
assert pos != -1
|
||||
pos += len(msg)
|
||||
|
||||
assert output.find('Train-error', pos) == -1
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
|
||||
|
||||
def test_early_stopping(self):
|
||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user