Add period to evaluation monitor. (#6348)

This commit is contained in:
Jiaming Yuan
2020-11-10 07:47:48 +08:00
committed by GitHub
parent d411f98d26
commit 184e2eac7d
8 changed files with 72 additions and 40 deletions

View File

@@ -583,12 +583,18 @@ class EvaluationMonitor(TrainingCallback):
Extra user defined metric.
rank : int
Which worker should be used for printing the result.
period : int
How many epoches between printing.
show_stdv : bool
Used in cv to show standard deviation. Users should not specify it.
'''
def __init__(self, rank=0, show_stdv=False):
def __init__(self, rank=0, period=1, show_stdv=False):
self.printer_rank = rank
self.show_stdv = show_stdv
self.period = period
assert period > 0
# last error message, useful when early stopping and period are used together.
self._lastest = None
super().__init__()
def _fmt_metric(self, data, metric, score, std):
@@ -601,6 +607,7 @@ class EvaluationMonitor(TrainingCallback):
def after_iteration(self, model, epoch, evals_log):
if not evals_log:
return False
msg = f'[{epoch}]'
if rabit.get_rank() == self.printer_rank:
for data, metric in evals_log.items():
@@ -613,9 +620,20 @@ class EvaluationMonitor(TrainingCallback):
stdv = None
msg += self._fmt_metric(data, metric_name, score, stdv)
msg += '\n'
rabit.tracker_print(msg)
if (epoch % self.period) != 0:
rabit.tracker_print(msg)
self._lastest = None
else:
# There is skipped message
self._lastest = msg
return False
def after_training(self, model):
if rabit.get_rank() == self.printer_rank and self._lastest is not None:
rabit.tracker_print(self._lastest)
return model
class TrainingCheckPoint(TrainingCallback):
'''Checkpointing operation.

View File

@@ -92,7 +92,8 @@ def _train_internal(params, dtrain,
assert all(isinstance(c, callback.TrainingCallback)
for c in callbacks), "You can't mix new and old callback styles."
if verbose_eval:
callbacks.append(callback.EvaluationMonitor())
verbose_eval = 1 if verbose_eval is True else verbose_eval
callbacks.append(callback.EvaluationMonitor(period=verbose_eval))
if early_stopping_rounds:
callbacks.append(callback.EarlyStopping(
rounds=early_stopping_rounds, maximize=maximize))
@@ -485,7 +486,9 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
assert all(isinstance(c, callback.TrainingCallback)
for c in callbacks), "You can't mix new and old callback styles."
if isinstance(verbose_eval, bool) and verbose_eval:
callbacks.append(callback.EvaluationMonitor(show_stdv=show_stdv))
verbose_eval = 1 if verbose_eval is True else verbose_eval
callbacks.append(callback.EvaluationMonitor(period=verbose_eval,
show_stdv=show_stdv))
if early_stopping_rounds:
callbacks.append(callback.EarlyStopping(
rounds=early_stopping_rounds, maximize=maximize))