Add period to evaluation monitor. (#6348)
This commit is contained in:
@@ -583,12 +583,18 @@ class EvaluationMonitor(TrainingCallback):
|
||||
Extra user defined metric.
|
||||
rank : int
|
||||
Which worker should be used for printing the result.
|
||||
period : int
|
||||
How many epoches between printing.
|
||||
show_stdv : bool
|
||||
Used in cv to show standard deviation. Users should not specify it.
|
||||
'''
|
||||
def __init__(self, rank=0, show_stdv=False):
|
||||
def __init__(self, rank=0, period=1, show_stdv=False):
|
||||
self.printer_rank = rank
|
||||
self.show_stdv = show_stdv
|
||||
self.period = period
|
||||
assert period > 0
|
||||
# last error message, useful when early stopping and period are used together.
|
||||
self._lastest = None
|
||||
super().__init__()
|
||||
|
||||
def _fmt_metric(self, data, metric, score, std):
|
||||
@@ -601,6 +607,7 @@ class EvaluationMonitor(TrainingCallback):
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
if not evals_log:
|
||||
return False
|
||||
|
||||
msg = f'[{epoch}]'
|
||||
if rabit.get_rank() == self.printer_rank:
|
||||
for data, metric in evals_log.items():
|
||||
@@ -613,9 +620,20 @@ class EvaluationMonitor(TrainingCallback):
|
||||
stdv = None
|
||||
msg += self._fmt_metric(data, metric_name, score, stdv)
|
||||
msg += '\n'
|
||||
rabit.tracker_print(msg)
|
||||
|
||||
if (epoch % self.period) != 0:
|
||||
rabit.tracker_print(msg)
|
||||
self._lastest = None
|
||||
else:
|
||||
# There is skipped message
|
||||
self._lastest = msg
|
||||
return False
|
||||
|
||||
def after_training(self, model):
|
||||
if rabit.get_rank() == self.printer_rank and self._lastest is not None:
|
||||
rabit.tracker_print(self._lastest)
|
||||
return model
|
||||
|
||||
|
||||
class TrainingCheckPoint(TrainingCallback):
|
||||
'''Checkpointing operation.
|
||||
|
||||
@@ -92,7 +92,8 @@ def _train_internal(params, dtrain,
|
||||
assert all(isinstance(c, callback.TrainingCallback)
|
||||
for c in callbacks), "You can't mix new and old callback styles."
|
||||
if verbose_eval:
|
||||
callbacks.append(callback.EvaluationMonitor())
|
||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||
callbacks.append(callback.EvaluationMonitor(period=verbose_eval))
|
||||
if early_stopping_rounds:
|
||||
callbacks.append(callback.EarlyStopping(
|
||||
rounds=early_stopping_rounds, maximize=maximize))
|
||||
@@ -485,7 +486,9 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
||||
assert all(isinstance(c, callback.TrainingCallback)
|
||||
for c in callbacks), "You can't mix new and old callback styles."
|
||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
||||
callbacks.append(callback.EvaluationMonitor(show_stdv=show_stdv))
|
||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||
callbacks.append(callback.EvaluationMonitor(period=verbose_eval,
|
||||
show_stdv=show_stdv))
|
||||
if early_stopping_rounds:
|
||||
callbacks.append(callback.EarlyStopping(
|
||||
rounds=early_stopping_rounds, maximize=maximize))
|
||||
|
||||
Reference in New Issue
Block a user