Fix period in evaluation monitor. (#6441)
This commit is contained in:
parent
8a0db293c5
commit
a2c778e2d1
@ -621,7 +621,7 @@ class EvaluationMonitor(TrainingCallback):
|
|||||||
msg += self._fmt_metric(data, metric_name, score, stdv)
|
msg += self._fmt_metric(data, metric_name, score, stdv)
|
||||||
msg += '\n'
|
msg += '\n'
|
||||||
|
|
||||||
if (epoch % self.period) != 0:
|
if (epoch % self.period) != 0 or self.period == 1:
|
||||||
rabit.tracker_print(msg)
|
rabit.tracker_print(msg)
|
||||||
self._latest = None
|
self._latest = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -22,6 +22,27 @@ class TestCallbacks:
|
|||||||
cls.X_valid = X[split:, ...]
|
cls.X_valid = X[split:, ...]
|
||||||
cls.y_valid = y[split:, ...]
|
cls.y_valid = y[split:, ...]
|
||||||
|
|
||||||
|
def run_evaluation_monitor(self, D_train, D_valid, rounds, verbose_eval):
|
||||||
|
evals_result = {}
|
||||||
|
with tm.captured_output() as (out, err):
|
||||||
|
xgb.train({'objective': 'binary:logistic',
|
||||||
|
'eval_metric': 'error'}, D_train,
|
||||||
|
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
||||||
|
num_boost_round=rounds,
|
||||||
|
evals_result=evals_result,
|
||||||
|
verbose_eval=verbose_eval)
|
||||||
|
output: str = out.getvalue().strip()
|
||||||
|
|
||||||
|
pos = 0
|
||||||
|
msg = 'Train-error'
|
||||||
|
for i in range(rounds // int(verbose_eval)):
|
||||||
|
pos = output.find('Train-error', pos)
|
||||||
|
assert pos != -1
|
||||||
|
pos += len(msg)
|
||||||
|
|
||||||
|
assert output.find('Train-error', pos) == -1
|
||||||
|
|
||||||
|
|
||||||
def test_evaluation_monitor(self):
|
def test_evaluation_monitor(self):
|
||||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||||
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
|
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
|
||||||
@ -36,23 +57,8 @@ class TestCallbacks:
|
|||||||
assert len(evals_result['Train']['error']) == rounds
|
assert len(evals_result['Train']['error']) == rounds
|
||||||
assert len(evals_result['Valid']['error']) == rounds
|
assert len(evals_result['Valid']['error']) == rounds
|
||||||
|
|
||||||
with tm.captured_output() as (out, err):
|
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
|
||||||
xgb.train({'objective': 'binary:logistic',
|
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
|
||||||
'eval_metric': 'error'}, D_train,
|
|
||||||
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
|
||||||
num_boost_round=rounds,
|
|
||||||
evals_result=evals_result,
|
|
||||||
verbose_eval=2)
|
|
||||||
output: str = out.getvalue().strip()
|
|
||||||
|
|
||||||
pos = 0
|
|
||||||
msg = 'Train-error'
|
|
||||||
for i in range(rounds // 2):
|
|
||||||
pos = output.find('Train-error', pos)
|
|
||||||
assert pos != -1
|
|
||||||
pos += len(msg)
|
|
||||||
|
|
||||||
assert output.find('Train-error', pos) == -1
|
|
||||||
|
|
||||||
def test_early_stopping(self):
|
def test_early_stopping(self):
|
||||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user