# coding: utf-8
# pylint: disable=invalid-name, too-many-statements, no-self-use
# pylint: disable=too-many-arguments
"""Training Library containing training routines."""
from abc import ABC
import collections
import os
import pickle
from typing import Callable, List, Optional, Union, Dict, Tuple
import numpy

from . import rabit
from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError
from .compat import STRING_TYPES


def _get_callback_context(env):
    """return whether the current callback context is cv or train"""
    if env.model is not None and env.cvfolds is None:
        context = 'train'
    elif env.model is None and env.cvfolds is not None:
        context = 'cv'
    else:
        raise ValueError("Unexpected input with both model and cvfolds.")
    return context


def _fmt_metric(value, show_stdv=True):
    """format metric string"""
    if len(value) == 2:
        return '{0}:{1:.5f}'.format(value[0], value[1])
    if len(value) == 3:
        if show_stdv:
            return '{0}:{1:.5f}+{2:.5f}'.format(value[0], value[1], value[2])
        return '{0}:{1:.5f}'.format(value[0], value[1])
    raise ValueError("wrong metric value", value)


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    We print the evaluation results every **period** iterations
    and on the first and the last iterations.

    Parameters
    ----------
    period : int
        The period to log the evaluation results.

    show_stdv : bool, optional
        Whether to show the standard deviation if provided.

    Returns
    -------
    callback : function
        A callback that prints the evaluation results every **period** iterations.
    """
    def callback(env):
        """internal function"""
        if env.rank != 0 or (not env.evaluation_result_list) or period is False or period == 0:
            return
        i = env.iteration
        if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration:
            msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
            rabit.tracker_print('[%d]\t%s\n' % (i, msg))
    return callback
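
# A minimal usage sketch for the callback above (illustrative only; it assumes
# ``import xgboost as xgb`` plus an existing parameter dict ``params`` and
# DMatrix objects ``dtrain``/``dvalid``):
#
#     bst = xgb.train(params, dtrain, num_boost_round=20,
#                     evals=[(dvalid, 'validation')],
#                     callbacks=[xgb.callback.print_evaluation(period=5)])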


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into **eval_result**.

    Parameters
    ----------
    eval_result : dict
        A dictionary used to store the evaluation results.

    Returns
    -------
    callback : function
        The requested callback function.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result has to be a dictionary')
    eval_result.clear()

    def init(env):
        """internal function"""
        for k, _ in env.evaluation_result_list:
            pos = k.index('-')
            key = k[:pos]
            metric = k[pos + 1:]
            if key not in eval_result:
                eval_result[key] = {}
            if metric not in eval_result[key]:
                eval_result[key][metric] = []

    def callback(env):
        """internal function"""
        if not eval_result:
            init(env)
        for k, v in env.evaluation_result_list:
            pos = k.index('-')
            key = k[:pos]
            metric = k[pos + 1:]
            eval_result[key][metric].append(v)
    return callback
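
# Illustrative sketch of the recorded history shape (assumes ``import xgboost
# as xgb`` and existing ``params``, ``dtrain``, ``dvalid``):
#
#     history = {}
#     xgb.train(params, dtrain, num_boost_round=10,
#               evals=[(dvalid, 'validation')],
#               callbacks=[record_evaluation(history)])
#     # history == {'validation': {'rmse': [..., ...]}} for an 'rmse' metric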


def reset_learning_rate(learning_rates):
    """Reset the learning rate after iteration 1.

    NOTE: the initial learning rate will still take effect on the first iteration.

    Parameters
    ----------
    learning_rates: list or function
        List of learning rates for each boosting round
        or a customized function that calculates eta in terms of
        the current number of rounds and the total number of boosting rounds (e.g.
        yields learning rate decay)

        * list ``l``: ``eta = l[boosting_round]``
        * function ``f``: ``eta = f(boosting_round, num_boost_round)``

    Returns
    -------
    callback : function
        The requested callback function.
    """
    def get_learning_rate(i, n, learning_rates):
        """helper providing the learning rate"""
        if isinstance(learning_rates, list):
            if len(learning_rates) != n:
                raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
            new_learning_rate = learning_rates[i]
        else:
            new_learning_rate = learning_rates(i, n)
        return new_learning_rate

    def callback(env):
        """internal function"""
        context = _get_callback_context(env)

        if context == 'train':
            bst, i, n = env.model, env.iteration, env.end_iteration
            bst.set_param(
                'learning_rate', get_learning_rate(i, n, learning_rates))
        elif context == 'cv':
            i, n = env.iteration, env.end_iteration
            for cvpack in env.cvfolds:
                bst = cvpack.bst
                bst.set_param(
                    'learning_rate', get_learning_rate(i, n, learning_rates))

    callback.before_iteration = False
    return callback
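
# Two illustrative ways of driving the schedule above (sketch; ``params`` and
# ``dtrain`` are assumed to exist):
#
#     # explicit list, one learning rate per boosting round
#     xgb.train(params, dtrain, num_boost_round=3,
#               callbacks=[reset_learning_rate([0.3, 0.2, 0.1])])
#
#     # decay computed from the round index i and total rounds n
#     xgb.train(params, dtrain, num_boost_round=10,
#               callbacks=[reset_learning_rate(lambda i, n: 0.3 * 0.9 ** i)])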


def early_stop(stopping_rounds, maximize=False, verbose=True):
    """Create a callback that activates early stopping.

    Validation error needs to decrease at least
    every **stopping_rounds** round(s) to continue training.
    Requires at least one item in **evals**.
    If there's more than one, will use the last.
    Returns the model from the last iteration (not the best one).
    If early stopping occurs, the model will have two additional fields:
    ``bst.best_score`` and ``bst.best_iteration``.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without improvement after which training stops.

    maximize : bool
        Whether to maximize the evaluation metric.

    verbose : optional, bool
        Whether to print messages about early stopping information.

    Returns
    -------
    callback : function
        The requested callback function.
    """
    state = {}

    def init(env):
        """internal function"""
        bst = env.model

        if not env.evaluation_result_list:
            raise ValueError('For early stopping you need at least one set in evals.')
        if len(env.evaluation_result_list) > 1 and verbose:
            msg = ("Multiple eval metrics have been passed: "
                   "'{0}' will be used for early stopping.\n\n")
            rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
        maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg')
        maximize_at_n_metrics = ('auc@', 'aucpr@', 'map@', 'ndcg@')
        maximize_score = maximize
        metric_label = env.evaluation_result_list[-1][0]
        metric = metric_label.split('-', 1)[-1]

        if any(metric.startswith(x) for x in maximize_at_n_metrics):
            maximize_score = True

        if any(metric.split(":")[0] == x for x in maximize_metrics):
            maximize_score = True

        if verbose and env.rank == 0:
            msg = "Will train until {} hasn't improved in {} rounds.\n"
            rabit.tracker_print(msg.format(metric_label, stopping_rounds))

        state['maximize_score'] = maximize_score
        state['best_iteration'] = 0
        if maximize_score:
            state['best_score'] = float('-inf')
        else:
            state['best_score'] = float('inf')
        msg = '[%d]\t%s' % (
            env.iteration,
            '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
        state['best_msg'] = msg

        if bst is not None:
            if bst.attr('best_score') is not None:
                state['best_score'] = float(bst.attr('best_score'))
                state['best_iteration'] = int(bst.attr('best_iteration'))
                state['best_msg'] = bst.attr('best_msg')
            else:
                bst.set_attr(best_iteration=str(state['best_iteration']))
                bst.set_attr(best_score=str(state['best_score']))
        else:
            assert env.cvfolds is not None

    def callback(env):
        """internal function"""
        if not state:
            init(env)
        score = env.evaluation_result_list[-1][1]
        best_score = state['best_score']
        best_iteration = state['best_iteration']
        maximize_score = state['maximize_score']
        if (maximize_score and score > best_score) or \
                (not maximize_score and score < best_score):
            msg = '[%d]\t%s' % (
                env.iteration,
                '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
            state['best_msg'] = msg
            state['best_score'] = score
            state['best_iteration'] = env.iteration
            # save the property to attributes, so they will occur in checkpoint.
            if env.model is not None:
                env.model.set_attr(best_score=str(state['best_score']),
                                   best_iteration=str(state['best_iteration']),
                                   best_msg=state['best_msg'])
        elif env.iteration - best_iteration >= stopping_rounds:
            best_msg = state['best_msg']
            if verbose and env.rank == 0:
                msg = "Stopping. Best iteration:\n{}\n\n"
                rabit.tracker_print(msg.format(best_msg))
            raise EarlyStopException(best_iteration)
    return callback
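
# Illustrative sketch of the legacy early-stopping callback in use (assumes
# ``params``, ``dtrain`` and ``dvalid`` exist):
#
#     bst = xgb.train(params, dtrain, num_boost_round=1000,
#                     evals=[(dvalid, 'validation')],
#                     callbacks=[early_stop(stopping_rounds=10, verbose=True)])
#     # After stopping, bst.attr('best_iteration') and bst.attr('best_score')
#     # record the best round seen on the last dataset in ``evals``.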


# The new implementation of callback functions.
# Breaking:
# - reset learning rate no longer accepts total boosting rounds

# pylint: disable=unused-argument
class TrainingCallback(ABC):
    '''Interface for training callback.

    .. versionadded:: 1.3.0

    '''

    EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]

    def __init__(self):
        pass

    def before_training(self, model):
        '''Run before training starts.'''
        return model

    def after_training(self, model):
        '''Run after training is finished.'''
        return model

    def before_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool:
        '''Run before each iteration. Return True when training should stop.'''
        return False

    def after_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool:
        '''Run after each iteration. Return True when training should stop.'''
        return False
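
# A minimal sketch of a custom callback built on the interface above; the class
# name and its logging behaviour are purely illustrative:
#
#     class PrintRound(TrainingCallback):
#         def after_iteration(self, model, epoch, evals_log):
#             print(f'finished round {epoch}')
#             return False  # keep training
#
#     # xgb.train(params, dtrain, callbacks=[PrintRound()])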


def _aggcv(rlist):
    # pylint: disable=invalid-name
    """Aggregate cross-validation results."""
    cvmap = {}
    idx = rlist[0].split()[0]
    for line in rlist:
        arr = line.split()
        assert idx == arr[0]
        for metric_idx, it in enumerate(arr[1:]):
            if not isinstance(it, STRING_TYPES):
                it = it.decode()
            k, v = it.split(':')
            if (metric_idx, k) not in cvmap:
                cvmap[(metric_idx, k)] = []
            cvmap[(metric_idx, k)].append(float(v))
    msg = idx
    results = []
    for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
        v = numpy.array(v)
        if not isinstance(msg, STRING_TYPES):
            msg = msg.decode()
        mean, std = numpy.mean(v), numpy.std(v)
        results.extend([(k, mean, std)])
    return results
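
# Shape of the aggregation above, as a sketch: each entry of ``rlist`` is one
# fold's evaluation string for the same round, and the result is a list of
# (metric_name, mean, std) tuples, e.g. (values approximate):
#
#     _aggcv(['[0]\ttrain-rmse:1.00\ttest-rmse:1.20',
#             '[0]\ttrain-rmse:0.90\ttest-rmse:1.10'])
#     # -> [('train-rmse', 0.95, 0.05), ('test-rmse', 1.15, 0.05)]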


def _allreduce_metric(score):
    '''Helper function for computing a customized metric in a distributed
    environment.  Not strictly correct, as many metrics don't use the mean
    value as the final result.

    '''
    world = rabit.get_world_size()
    assert world != 0
    if world == 1:
        return score
    if isinstance(score, tuple):  # has mean and stdv
        raise ValueError(
            'xgboost.cv function should not be used in distributed environment.')
    score = numpy.array([score])
    score = rabit.allreduce(score, rabit.Op.SUM) / world
    return score[0]


class CallbackContainer:
    '''A special callback for invoking a list of other callbacks.

    .. versionadded:: 1.3.0

    '''

    EvalsLog = TrainingCallback.EvalsLog

    def __init__(self,
                 callbacks: List[TrainingCallback],
                 metric: Callable = None,
                 is_cv: bool = False):
        self.callbacks = set(callbacks)
        if metric is not None:
            msg = 'metric must be callable object for monitoring. For ' + \
                'builtin metrics, passing them in training parameter' + \
                ' will invoke monitor automatically.'
            assert callable(metric), msg
        self.metric = metric
        self.history: CallbackContainer.EvalsLog = collections.OrderedDict()
        self.is_cv = is_cv

        if self.is_cv:
            self.aggregated_cv = None

    def before_training(self, model):
        '''Function called before training.'''
        for c in self.callbacks:
            model = c.before_training(model=model)
            msg = 'before_training should return the model'
            if self.is_cv:
                assert isinstance(model.cvfolds, list), msg
            else:
                assert isinstance(model, Booster), msg
        return model

    def after_training(self, model):
        '''Function called after training.'''
        for c in self.callbacks:
            model = c.after_training(model=model)
            msg = 'after_training should return the model'
            if self.is_cv:
                assert isinstance(model.cvfolds, list), msg
            else:
                assert isinstance(model, Booster), msg
        return model

    def before_iteration(self, model, epoch, dtrain, evals) -> bool:
        '''Function called before training iteration.'''
        return any(c.before_iteration(model, epoch, self.history)
                   for c in self.callbacks)

    def _update_history(self, score, epoch):
        for d in score:
            name, s = d[0], float(d[1])
            if self.is_cv:
                std = float(d[2])
                s = (s, std)
            split_names = name.split('-')
            data_name = split_names[0]
            metric_name = '-'.join(split_names[1:])
            s = _allreduce_metric(s)
            if data_name in self.history:
                data_history = self.history[data_name]
                if metric_name in data_history:
                    data_history[metric_name].append(s)
                else:
                    data_history[metric_name] = [s]
            else:
                self.history[data_name] = collections.OrderedDict()
                self.history[data_name][metric_name] = [s]
        return False

    def after_iteration(self, model, epoch, dtrain, evals) -> bool:
        '''Function called after training iteration.'''
        if self.is_cv:
            scores = model.eval(epoch, self.metric)
            scores = _aggcv(scores)
            self.aggregated_cv = scores
            self._update_history(scores, epoch)
        else:
            evals = [] if evals is None else evals
            for _, name in evals:
                assert name.find('-') == -1, 'Dataset name should not contain `-`'
            score = model.eval_set(evals, epoch, self.metric)
            score = score.split()[1:]  # into datasets
            # split up `test-error:0.1234`
            score = [tuple(s.split(':')) for s in score]
            self._update_history(score, epoch)
        ret = any(c.after_iteration(model, epoch, self.history)
                  for c in self.callbacks)
        return ret
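
# Rough sketch of how a training loop wires this container up (simplified and
# not the exact implementation; the surrounding names ``bst``, ``dtrain``,
# ``evals``, ``obj`` and ``custom_metric`` are assumptions):
#
#     callbacks = CallbackContainer([EvaluationMonitor(), EarlyStopping(rounds=5)],
#                                   metric=custom_metric)
#     bst = callbacks.before_training(bst)
#     for epoch in range(num_boost_round):
#         if callbacks.before_iteration(bst, epoch, dtrain, evals):
#             break
#         bst.update(dtrain, epoch, obj)
#         if callbacks.after_iteration(bst, epoch, dtrain, evals):
#             break
#     bst = callbacks.after_training(bst)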


class LearningRateScheduler(TrainingCallback):
    '''Callback function for scheduling learning rate.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    learning_rates : callable/collections.Sequence
        If it's a callable object, then it should accept an integer parameter
        `epoch` and return the corresponding learning rate.  Otherwise it
        should be a sequence like list or tuple with the same length as the
        number of boosting rounds.

    '''
    def __init__(self, learning_rates) -> None:
        assert callable(learning_rates) or \
            isinstance(learning_rates, collections.abc.Sequence)
        if callable(learning_rates):
            self.learning_rates = learning_rates
        else:
            self.learning_rates = lambda epoch: learning_rates[epoch]
        super().__init__()

    def after_iteration(self, model, epoch, evals_log) -> bool:
        model.set_param('learning_rate', self.learning_rates(epoch))
        return False
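
# Illustrative usage of the scheduler above (assumes ``params`` and ``dtrain``
# exist); either form is accepted:
#
#     scheduler = LearningRateScheduler([0.3, 0.2, 0.1])                   # one rate per round
#     scheduler = LearningRateScheduler(lambda epoch: 0.3 * 0.9 ** epoch)  # decay
#     xgb.train(params, dtrain, num_boost_round=3, callbacks=[scheduler])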


# pylint: disable=too-many-instance-attributes
class EarlyStopping(TrainingCallback):
    """Callback function for early stopping

    .. versionadded:: 1.3.0

    Parameters
    ----------
    rounds
        Early stopping rounds.
    metric_name
        Name of metric that is used for early stopping.
    data_name
        Name of dataset that is used for early stopping.
    maximize
        Whether to maximize evaluation metric. None means auto (discouraged).
    save_best
        Whether training should return the best model or the last model.
    abs_tol
        Absolute tolerance for the early stopping condition.

        .. versionadded:: 1.5.0

    .. code-block:: python

        clf = xgboost.XGBClassifier(tree_method="gpu_hist")
        es = xgboost.callback.EarlyStopping(
            rounds=2,
            abs_tol=1e-3,
            save_best=True,
            maximize=False,
            data_name="validation_0",
            metric_name="mlogloss",
        )

        X, y = load_digits(return_X_y=True)
        clf.fit(X, y, eval_set=[(X, y)], callbacks=[es])
    """
    def __init__(self,
                 rounds: int,
                 metric_name: Optional[str] = None,
                 data_name: Optional[str] = None,
                 maximize: Optional[bool] = None,
                 save_best: Optional[bool] = False,
                 abs_tol: float = 0) -> None:
        self.data = data_name
        self.metric_name = metric_name
        self.rounds = rounds
        self.save_best = save_best
        self.maximize = maximize
        self.stopping_history: CallbackContainer.EvalsLog = {}
        self._tol = abs_tol
        if self._tol < 0:
            raise ValueError("tolerance must be greater or equal to 0.")

        self.improve_op = None

        self.current_rounds: int = 0
        self.best_scores: dict = {}
        self.starting_round: int = 0
        super().__init__()

    def before_training(self, model):
        self.starting_round = model.num_boosted_rounds()
        return model

    def _update_rounds(self, score, name, metric, model, epoch) -> bool:
        def get_s(x):
            """get score if it's cross validation history."""
            return x[0] if isinstance(x, tuple) else x

        def maximize(new, best):
            return numpy.greater(get_s(new) + self._tol, get_s(best))

        def minimize(new, best):
            return numpy.greater(get_s(best) + self._tol, get_s(new))

        if self.maximize is None:
            # Just to be compatible with the old behavior before 1.3.  We should
            # let the user decide.
            maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
                                'aucpr@', 'map@', 'ndcg@')
            if metric != 'mape' and any(metric.startswith(x) for x in maximize_metrics):
                self.maximize = True
            else:
                self.maximize = False

        if self.maximize:
            self.improve_op = maximize
        else:
            self.improve_op = minimize

        assert self.improve_op

        if not self.stopping_history:  # First round
            self.current_rounds = 0
            self.stopping_history[name] = {}
            self.stopping_history[name][metric] = [score]
            self.best_scores[name] = {}
            self.best_scores[name][metric] = [score]
            model.set_attr(best_score=str(score), best_iteration=str(epoch))
        elif not self.improve_op(score, self.best_scores[name][metric][-1]):
            # Not improved
            self.stopping_history[name][metric].append(score)
            self.current_rounds += 1
        else:  # Improved
            self.stopping_history[name][metric].append(score)
            self.best_scores[name][metric].append(score)
            record = self.stopping_history[name][metric][-1]
            model.set_attr(best_score=str(record), best_iteration=str(epoch))
            self.current_rounds = 0  # reset

        if self.current_rounds >= self.rounds:
            # Should stop
            return True
        return False

    def after_iteration(self, model, epoch: int,
                        evals_log: CallbackContainer.EvalsLog) -> bool:
        epoch += self.starting_round  # training continuation
        msg = 'Must have at least 1 validation dataset for early stopping.'
        assert len(evals_log.keys()) >= 1, msg
        data_name = ''
        if self.data:
            for d, _ in evals_log.items():
                if d == self.data:
                    data_name = d
            if not data_name:
                raise ValueError('No dataset named:', self.data)
        else:
            # Use the last one as default.
            data_name = list(evals_log.keys())[-1]
        assert isinstance(data_name, str) and data_name
        data_log = evals_log[data_name]

        # Filter out scores that can not be used for early stopping.
        if self.metric_name:
            metric_name = self.metric_name
        else:
            # Use the last metric by default.
            assert isinstance(data_log, collections.OrderedDict)
            metric_name = list(data_log.keys())[-1]
        score = data_log[metric_name][-1]
        return self._update_rounds(score, data_name, metric_name, model, epoch)

    def after_training(self, model):
        try:
            if self.save_best:
                model = model[: int(model.attr("best_iteration")) + 1]
        except XGBoostError as e:
            raise XGBoostError(
                "`save_best` is not applicable to current booster"
            ) from e
        return model
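
# A short sketch of the same callback with the native API (``params``,
# ``dtrain`` and ``dvalid`` are assumed to exist):
#
#     es = EarlyStopping(rounds=10, save_best=True)
#     bst = xgb.train(params, dtrain, num_boost_round=1000,
#                     evals=[(dvalid, 'validation')], callbacks=[es])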


class EvaluationMonitor(TrainingCallback):
    '''Print the evaluation result at each iteration.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    rank : int
        Which worker should be used for printing the result.
    period : int
        How many epochs between printing.
    show_stdv : bool
        Used in cv to show standard deviation.  Users should not specify it.
    '''
    def __init__(self, rank=0, period=1, show_stdv=False) -> None:
        self.printer_rank = rank
        self.show_stdv = show_stdv
        self.period = period
        assert period > 0
        # last evaluation message, useful when early stopping and period are used together.
        self._latest: Optional[str] = None
        super().__init__()

    def _fmt_metric(self, data, metric, score, std) -> str:
        if std is not None and self.show_stdv:
            msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
        else:
            msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
        return msg

    def after_iteration(self, model, epoch: int,
                        evals_log: CallbackContainer.EvalsLog) -> bool:
        if not evals_log:
            return False

        msg: str = f'[{epoch}]'
        if rabit.get_rank() == self.printer_rank:
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    stdv: Optional[float] = None
                    if isinstance(log[-1], tuple):
                        score = log[-1][0]
                        stdv = log[-1][1]
                    else:
                        score = log[-1]
                    msg += self._fmt_metric(data, metric_name, score, stdv)
            msg += '\n'

            if (epoch % self.period) == 0 or self.period == 1:
                rabit.tracker_print(msg)
                self._latest = None
            else:
                # There is a skipped message.
                self._latest = msg
        return False

    def after_training(self, model):
        if rabit.get_rank() == self.printer_rank and self._latest is not None:
            rabit.tracker_print(self._latest)
        return model
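
# Sketch of pairing the monitor with a printing period (assumes ``params`` and
# ``dtrain`` exist); with ``period=5`` only every fifth round is printed, and
# the final skipped line is flushed in ``after_training``:
#
#     xgb.train(params, dtrain, num_boost_round=100,
#               evals=[(dtrain, 'train')],
#               callbacks=[EvaluationMonitor(period=5)])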


class TrainingCheckPoint(TrainingCallback):
    '''Checkpointing operation.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    directory : os.PathLike
        Output model directory.
    name : str
        Pattern of the output model file.  Models will be saved as name_0.json,
        name_1.json, name_2.json ....
    as_pickle : boolean
        When set to True, all training parameters will be saved in pickle format,
        instead of saving only the model.
    iterations : int
        Interval of checkpointing.  Checkpointing is slow, so setting a larger
        number can reduce the performance hit.

    '''
    def __init__(self, directory: os.PathLike, name: str = 'model',
                 as_pickle=False, iterations: int = 100):
        self._path = directory
        self._name = name
        self._as_pickle = as_pickle
        self._iterations = iterations
        self._epoch = 0
        super().__init__()

    def after_iteration(self, model, epoch: int,
                        evals_log: CallbackContainer.EvalsLog) -> bool:
        if self._epoch == self._iterations:
            path = os.path.join(self._path, self._name + '_' + str(epoch) +
                                ('.pkl' if self._as_pickle else '.json'))
            self._epoch = 0
            if rabit.get_rank() == 0:
                if self._as_pickle:
                    with open(path, 'wb') as fd:
                        pickle.dump(model, fd)
                else:
                    model.save_model(path)
        self._epoch += 1
        return False
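
# Illustrative sketch of checkpointing every 10 rounds into a directory of your
# choice (``params`` and ``dtrain`` are assumed to exist):
#
#     ckpt = TrainingCheckPoint(directory='/tmp/xgb_ckpt', name='model',
#                               iterations=10)
#     xgb.train(params, dtrain, num_boost_round=100, callbacks=[ckpt])
#     # produces /tmp/xgb_ckpt/model_10.json, model_20.json, ...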


class LegacyCallbacks:
    '''Adapter for legacy callback functions.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    callbacks : Sequence
        A sequence of legacy callbacks (callbacks that are not instances of
        TrainingCallback)
    start_iteration : int
        Beginning iteration.
    end_iteration : int
        End iteration, normally the number of boosting rounds.
    feval : callable
        Custom evaluation metric.
    '''
    def __init__(self, callbacks, start_iteration, end_iteration,
                 feval, cvfolds=None):
        self.callbacks_before_iter = [
            cb for cb in callbacks
            if cb.__dict__.get('before_iteration', False)]
        self.callbacks_after_iter = [
            cb for cb in callbacks
            if not cb.__dict__.get('before_iteration', False)]

        self.start_iteration = start_iteration
        self.end_iteration = end_iteration
        self.cvfolds = cvfolds

        self.feval = feval
        assert self.feval is None or callable(self.feval)

        if cvfolds is not None:
            self.aggregated_cv = None

        super().__init__()

    def before_training(self, model):
        '''Nothing to do for legacy callbacks'''
        return model

    def after_training(self, model):
        '''Nothing to do for legacy callbacks'''
        return model

    def before_iteration(self, model, epoch, dtrain, evals):
        '''Called before each iteration.'''
        for cb in self.callbacks_before_iter:
            rank = rabit.get_rank()
            cb(CallbackEnv(model=None if self.cvfolds is not None else model,
                           cvfolds=self.cvfolds,
                           iteration=epoch,
                           begin_iteration=self.start_iteration,
                           end_iteration=self.end_iteration,
                           rank=rank,
                           evaluation_result_list=None))
        return False

    def after_iteration(self, model, epoch, dtrain, evals):
        '''Called after each iteration.'''
        evaluation_result_list = []
        if self.cvfolds is not None:
            # dtrain is not used here.
            scores = model.eval(epoch, self.feval)
            self.aggregated_cv = _aggcv(scores)
            evaluation_result_list = self.aggregated_cv

        if evals:
            # When cv is used, evals are embedded into folds.
            assert self.cvfolds is None
            bst_eval_set = model.eval_set(evals, epoch, self.feval)
            if isinstance(bst_eval_set, STRING_TYPES):
                msg = bst_eval_set
            else:
                msg = bst_eval_set.decode()
            res = [x.split(':') for x in msg.split()]
            evaluation_result_list = [(k, float(v)) for k, v in res[1:]]

        try:
            for cb in self.callbacks_after_iter:
                rank = rabit.get_rank()
                cb(CallbackEnv(model=None if self.cvfolds is not None else model,
                               cvfolds=self.cvfolds,
                               iteration=epoch,
                               begin_iteration=self.start_iteration,
                               end_iteration=self.end_iteration,
                               rank=rank,
                               evaluation_result_list=evaluation_result_list))
        except EarlyStopException:
            return True

        return False