Support early stopping with training continuation, correct num boosted rounds. (#6506)
* Implement early stopping with training continuation. * Add new C API for obtaining boosted rounds. * Fix off by 1 in `save_best`. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -6,7 +6,7 @@ from abc import ABC
|
||||
import collections
|
||||
import os
|
||||
import pickle
|
||||
from typing import Callable, List
|
||||
from typing import Callable, List, Optional, Union, Dict, Tuple
|
||||
import numpy
|
||||
|
||||
from . import rabit
|
||||
@@ -285,11 +285,13 @@ class TrainingCallback(ABC):
|
||||
'''Run after training is finished.'''
|
||||
return model
|
||||
|
||||
def before_iteration(self, model, epoch, evals_log):
|
||||
def before_iteration(self, model, epoch: int,
|
||||
evals_log: 'CallbackContainer.EvalsLog') -> bool:
|
||||
'''Run before each iteration. Return True when training should stop.'''
|
||||
return False
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
def after_iteration(self, model, epoch: int,
|
||||
evals_log: 'CallbackContainer.EvalsLog') -> bool:
|
||||
'''Run after each iteration. Return True when training should stop.'''
|
||||
return False
|
||||
|
||||
@@ -346,8 +348,13 @@ class CallbackContainer:
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
'''
|
||||
def __init__(self, callbacks: List[TrainingCallback],
|
||||
metric: Callable = None, is_cv: bool = False):
|
||||
|
||||
EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
|
||||
|
||||
def __init__(self,
|
||||
callbacks: List[TrainingCallback],
|
||||
metric: Callable = None,
|
||||
is_cv: bool = False):
|
||||
self.callbacks = set(callbacks)
|
||||
if metric is not None:
|
||||
msg = 'metric must be callable object for monitoring. For ' + \
|
||||
@@ -355,7 +362,7 @@ class CallbackContainer:
|
||||
' will invoke monitor automatically.'
|
||||
assert callable(metric), msg
|
||||
self.metric = metric
|
||||
self.history = collections.OrderedDict()
|
||||
self.history: CallbackContainer.EvalsLog = collections.OrderedDict()
|
||||
self.is_cv = is_cv
|
||||
|
||||
if self.is_cv:
|
||||
@@ -383,7 +390,7 @@ class CallbackContainer:
|
||||
assert isinstance(model, Booster), msg
|
||||
return model
|
||||
|
||||
def before_iteration(self, model, epoch, dtrain, evals):
|
||||
def before_iteration(self, model, epoch, dtrain, evals) -> bool:
|
||||
'''Function called before training iteration.'''
|
||||
return any(c.before_iteration(model, epoch, self.history)
|
||||
for c in self.callbacks)
|
||||
@@ -409,7 +416,7 @@ class CallbackContainer:
|
||||
self.history[data_name][metric_name] = [s]
|
||||
return False
|
||||
|
||||
def after_iteration(self, model, epoch, dtrain, evals):
|
||||
def after_iteration(self, model, epoch, dtrain, evals) -> bool:
|
||||
'''Function called after training iteration.'''
|
||||
if self.is_cv:
|
||||
scores = model.eval(epoch, self.metric)
|
||||
@@ -445,7 +452,7 @@ class LearningRateScheduler(TrainingCallback):
|
||||
rounds.
|
||||
|
||||
'''
|
||||
def __init__(self, learning_rates):
|
||||
def __init__(self, learning_rates) -> None:
|
||||
assert callable(learning_rates) or \
|
||||
isinstance(learning_rates, collections.abc.Sequence)
|
||||
if callable(learning_rates):
|
||||
@@ -454,41 +461,42 @@ class LearningRateScheduler(TrainingCallback):
|
||||
self.learning_rates = lambda epoch: learning_rates[epoch]
|
||||
super().__init__()
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
def after_iteration(self, model, epoch, evals_log) -> bool:
|
||||
model.set_param('learning_rate', self.learning_rates(epoch))
|
||||
return False
|
||||
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
class EarlyStopping(TrainingCallback):
|
||||
''' Callback function for early stopping
|
||||
"""Callback function for early stopping
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rounds : int
|
||||
rounds
|
||||
Early stopping rounds.
|
||||
metric_name : str
|
||||
metric_name
|
||||
Name of metric that is used for early stopping.
|
||||
data_name: str
|
||||
data_name
|
||||
Name of dataset that is used for early stopping.
|
||||
maximize : bool
|
||||
maximize
|
||||
Whether to maximize evaluation metric. None means auto (discouraged).
|
||||
save_best : bool
|
||||
save_best
|
||||
Whether training should return the best model or the last model.
|
||||
'''
|
||||
"""
|
||||
def __init__(self,
|
||||
rounds,
|
||||
metric_name=None,
|
||||
data_name=None,
|
||||
maximize=None,
|
||||
save_best=False):
|
||||
rounds: int,
|
||||
metric_name: Optional[str] = None,
|
||||
data_name: Optional[str] = None,
|
||||
maximize: Optional[bool] = None,
|
||||
save_best: Optional[bool] = False) -> None:
|
||||
self.data = data_name
|
||||
self.metric_name = metric_name
|
||||
self.rounds = rounds
|
||||
self.save_best = save_best
|
||||
self.maximize = maximize
|
||||
self.stopping_history = {}
|
||||
self.stopping_history: CallbackContainer.EvalsLog = {}
|
||||
|
||||
if self.maximize is not None:
|
||||
if self.maximize:
|
||||
@@ -496,11 +504,16 @@ class EarlyStopping(TrainingCallback):
|
||||
else:
|
||||
self.improve_op = lambda x, y: x < y
|
||||
|
||||
self.current_rounds = 0
|
||||
self.best_scores = {}
|
||||
self.current_rounds: int = 0
|
||||
self.best_scores: dict = {}
|
||||
self.starting_round: int = 0
|
||||
super().__init__()
|
||||
|
||||
def _update_rounds(self, score, name, metric, model, epoch):
|
||||
def before_training(self, model):
|
||||
self.starting_round = model.num_boosted_rounds()
|
||||
return model
|
||||
|
||||
def _update_rounds(self, score, name, metric, model, epoch) -> bool:
|
||||
# Just to be compatibility with old behavior before 1.3. We should let
|
||||
# user to decide.
|
||||
if self.maximize is None:
|
||||
@@ -536,7 +549,9 @@ class EarlyStopping(TrainingCallback):
|
||||
return True
|
||||
return False
|
||||
|
||||
def after_iteration(self, model: Booster, epoch, evals_log):
|
||||
def after_iteration(self, model, epoch: int,
|
||||
evals_log: CallbackContainer.EvalsLog) -> bool:
|
||||
epoch += self.starting_round # training continuation
|
||||
msg = 'Must have at least 1 validation dataset for early stopping.'
|
||||
assert len(evals_log.keys()) >= 1, msg
|
||||
data_name = ''
|
||||
@@ -562,12 +577,14 @@ class EarlyStopping(TrainingCallback):
|
||||
score = data_log[metric_name][-1]
|
||||
return self._update_rounds(score, data_name, metric_name, model, epoch)
|
||||
|
||||
def after_training(self, model: Booster):
|
||||
def after_training(self, model):
|
||||
try:
|
||||
if self.save_best:
|
||||
model = model[: int(model.attr('best_iteration'))]
|
||||
model = model[: int(model.attr("best_iteration")) + 1]
|
||||
except XGBoostError as e:
|
||||
raise XGBoostError('`save_best` is not applicable to current booster') from e
|
||||
raise XGBoostError(
|
||||
"`save_best` is not applicable to current booster"
|
||||
) from e
|
||||
return model
|
||||
|
||||
|
||||
@@ -588,36 +605,37 @@ class EvaluationMonitor(TrainingCallback):
|
||||
show_stdv : bool
|
||||
Used in cv to show standard deviation. Users should not specify it.
|
||||
'''
|
||||
def __init__(self, rank=0, period=1, show_stdv=False):
|
||||
def __init__(self, rank=0, period=1, show_stdv=False) -> None:
|
||||
self.printer_rank = rank
|
||||
self.show_stdv = show_stdv
|
||||
self.period = period
|
||||
assert period > 0
|
||||
# last error message, useful when early stopping and period are used together.
|
||||
self._latest = None
|
||||
self._latest: Optional[str] = None
|
||||
super().__init__()
|
||||
|
||||
def _fmt_metric(self, data, metric, score, std):
|
||||
def _fmt_metric(self, data, metric, score, std) -> str:
|
||||
if std is not None and self.show_stdv:
|
||||
msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
|
||||
else:
|
||||
msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
|
||||
return msg
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
def after_iteration(self, model, epoch: int,
|
||||
evals_log: CallbackContainer.EvalsLog) -> bool:
|
||||
if not evals_log:
|
||||
return False
|
||||
|
||||
msg = f'[{epoch}]'
|
||||
msg: str = f'[{epoch}]'
|
||||
if rabit.get_rank() == self.printer_rank:
|
||||
for data, metric in evals_log.items():
|
||||
for metric_name, log in metric.items():
|
||||
stdv: Optional[float] = None
|
||||
if isinstance(log[-1], tuple):
|
||||
score = log[-1][0]
|
||||
stdv = log[-1][1]
|
||||
else:
|
||||
score = log[-1]
|
||||
stdv = None
|
||||
msg += self._fmt_metric(data, metric_name, score, stdv)
|
||||
msg += '\n'
|
||||
|
||||
@@ -665,7 +683,8 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
self._epoch = 0
|
||||
super().__init__()
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
def after_iteration(self, model, epoch: int,
|
||||
evals_log: CallbackContainer.EvalsLog) -> bool:
|
||||
if self._epoch == self._iterations:
|
||||
path = os.path.join(self._path, self._name + '_' + str(epoch) +
|
||||
('.pkl' if self._as_pickle else '.json'))
|
||||
@@ -677,6 +696,7 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
else:
|
||||
model.save_model(path)
|
||||
self._epoch += 1
|
||||
return False
|
||||
|
||||
|
||||
class LegacyCallbacks:
|
||||
|
||||
@@ -1177,23 +1177,6 @@ class Booster(object):
|
||||
"""
|
||||
return self.__copy__()
|
||||
|
||||
def load_rabit_checkpoint(self):
|
||||
"""Initialize the model by load from rabit checkpoint.
|
||||
|
||||
Returns
|
||||
-------
|
||||
version: integer
|
||||
The version number of the model.
|
||||
"""
|
||||
version = ctypes.c_int()
|
||||
_check_call(_LIB.XGBoosterLoadRabitCheckpoint(
|
||||
self.handle, ctypes.byref(version)))
|
||||
return version.value
|
||||
|
||||
def save_rabit_checkpoint(self):
|
||||
"""Save the current booster to rabit checkpoint."""
|
||||
_check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle))
|
||||
|
||||
def attr(self, key):
|
||||
"""Get attribute string from the Booster.
|
||||
|
||||
@@ -1745,6 +1728,17 @@ class Booster(object):
|
||||
else:
|
||||
raise TypeError('Unknown file type: ', fname)
|
||||
|
||||
def num_boosted_rounds(self) -> int:
|
||||
'''Get number of boosted rounds. For gblinear this is reset to 0 after
|
||||
serializing the model.
|
||||
|
||||
'''
|
||||
rounds = ctypes.c_int()
|
||||
assert self.handle is not None
|
||||
_check_call(_LIB.XGBoosterBoostedRounds(
|
||||
self.handle, ctypes.byref(rounds)))
|
||||
return rounds.value
|
||||
|
||||
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
|
||||
"""Dump model into a text or JSON file. Unlike `save_model`, the
|
||||
output format is primarily used for visualization or interpretation,
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import copy
|
||||
import warnings
|
||||
import json
|
||||
from typing import Union, Optional, List, Dict, Callable, Tuple, Any
|
||||
import numpy as np
|
||||
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
|
||||
from .training import train
|
||||
@@ -494,6 +495,22 @@ class XGBModel(XGBModelBase):
|
||||
# Delete the attribute after load
|
||||
self.get_booster().set_attr(scikit_learn=None)
|
||||
|
||||
def _configure_fit(
|
||||
self,
|
||||
booster: Optional[Booster],
|
||||
eval_metric: Optional[Union[Callable, str, List[str]]],
|
||||
params: Dict[str, Any],
|
||||
) -> Tuple[Booster, Optional[Union[Callable, str, List[str]]], Dict[str, Any]]:
|
||||
model = self._Booster if hasattr(self, "_Booster") else None
|
||||
model = booster if booster is not None else model
|
||||
feval = eval_metric if callable(eval_metric) else None
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
eval_metric = None
|
||||
else:
|
||||
params.update({"eval_metric": eval_metric})
|
||||
return model, feval, params
|
||||
|
||||
@_deprecate_positional_args
|
||||
def fit(self, X, y, *, sample_weight=None, base_margin=None,
|
||||
eval_set=None, eval_metric=None, early_stopping_rounds=None,
|
||||
@@ -586,19 +603,13 @@ class XGBModel(XGBModelBase):
|
||||
else:
|
||||
obj = None
|
||||
|
||||
feval = eval_metric if callable(eval_metric) else None
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
eval_metric = None
|
||||
else:
|
||||
params.update({'eval_metric': eval_metric})
|
||||
|
||||
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
|
||||
self._Booster = train(params, train_dmatrix,
|
||||
self.get_num_boosting_rounds(), evals=evals,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
evals_result=evals_result,
|
||||
obj=obj, feval=feval,
|
||||
verbose_eval=verbose, xgb_model=xgb_model,
|
||||
verbose_eval=verbose, xgb_model=model,
|
||||
callbacks=callbacks)
|
||||
|
||||
if evals_result:
|
||||
@@ -857,27 +868,20 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
not np.array_equal(self.classes_, np.arange(self.n_classes_))):
|
||||
raise ValueError(label_encoding_check_error)
|
||||
|
||||
xgb_options = self.get_xgb_params()
|
||||
params = self.get_xgb_params()
|
||||
|
||||
if callable(self.objective):
|
||||
obj = _objective_decorator(self.objective)
|
||||
# Use default value. Is it really not used ?
|
||||
xgb_options["objective"] = "binary:logistic"
|
||||
params["objective"] = "binary:logistic"
|
||||
else:
|
||||
obj = None
|
||||
|
||||
if self.n_classes_ > 2:
|
||||
# Switch to using a multiclass objective in the underlying
|
||||
# XGB instance
|
||||
xgb_options['objective'] = 'multi:softprob'
|
||||
xgb_options['num_class'] = self.n_classes_
|
||||
|
||||
feval = eval_metric if callable(eval_metric) else None
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
eval_metric = None
|
||||
else:
|
||||
xgb_options.update({"eval_metric": eval_metric})
|
||||
params['objective'] = 'multi:softprob'
|
||||
params['num_class'] = self.n_classes_
|
||||
|
||||
if self.use_label_encoder:
|
||||
if not can_use_label_encoder:
|
||||
@@ -891,6 +895,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
else:
|
||||
label_transform = (lambda x: x)
|
||||
|
||||
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
|
||||
if len(X.shape) != 2:
|
||||
# Simply raise an error here since there might be many
|
||||
# different ways of reshaping
|
||||
@@ -906,15 +911,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
eval_set=eval_set, sample_weight_eval_set=sample_weight_eval_set,
|
||||
eval_group=None, label_transform=label_transform)
|
||||
|
||||
self._Booster = train(xgb_options, train_dmatrix,
|
||||
self._Booster = train(params, train_dmatrix,
|
||||
self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
evals_result=evals_result, obj=obj, feval=feval,
|
||||
verbose_eval=verbose, xgb_model=xgb_model,
|
||||
verbose_eval=verbose, xgb_model=model,
|
||||
callbacks=callbacks)
|
||||
|
||||
self.objective = xgb_options["objective"]
|
||||
self.objective = params["objective"]
|
||||
if evals_result:
|
||||
for val in evals_result.items():
|
||||
evals_result_key = list(val[1].keys())[0]
|
||||
|
||||
@@ -4,11 +4,10 @@
|
||||
"""Training Library containing training routines."""
|
||||
import warnings
|
||||
import copy
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from .core import Booster, XGBoostError
|
||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
||||
from . import rabit
|
||||
from . import callback
|
||||
|
||||
|
||||
@@ -51,28 +50,12 @@ def _train_internal(params, dtrain,
|
||||
evals = list(evals)
|
||||
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||
nboost = 0
|
||||
num_parallel_tree = 1
|
||||
|
||||
if xgb_model is not None:
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals],
|
||||
model_file=xgb_model)
|
||||
nboost = len(bst.get_dump())
|
||||
|
||||
_params = dict(params) if isinstance(params, list) else params
|
||||
|
||||
if 'num_parallel_tree' in _params and _params[
|
||||
'num_parallel_tree'] is not None:
|
||||
num_parallel_tree = _params['num_parallel_tree']
|
||||
nboost //= num_parallel_tree
|
||||
if 'num_class' in _params and _params['num_class'] is not None:
|
||||
nboost //= _params['num_class']
|
||||
|
||||
# Distributed code: Load the checkpoint from rabit.
|
||||
version = bst.load_rabit_checkpoint()
|
||||
assert rabit.get_world_size() != 1 or version == 0
|
||||
start_iteration = int(version / 2)
|
||||
nboost += start_iteration
|
||||
start_iteration = 0
|
||||
|
||||
is_new_callback = _is_new_callback(callbacks)
|
||||
if is_new_callback:
|
||||
@@ -92,26 +75,13 @@ def _train_internal(params, dtrain,
|
||||
show_stdv=False, cvfolds=None)
|
||||
|
||||
bst = callbacks.before_training(bst)
|
||||
|
||||
for i in range(start_iteration, num_boost_round):
|
||||
if callbacks.before_iteration(bst, i, dtrain, evals):
|
||||
break
|
||||
# Distributed code: need to resume to this point.
|
||||
# Skip the first update if it is a recovery step.
|
||||
if version % 2 == 0:
|
||||
bst.update(dtrain, i, obj)
|
||||
bst.save_rabit_checkpoint()
|
||||
version += 1
|
||||
|
||||
assert rabit.get_world_size() == 1 or version == rabit.version_number()
|
||||
|
||||
nboost += 1
|
||||
# check evaluation result.
|
||||
bst.update(dtrain, i, obj)
|
||||
if callbacks.after_iteration(bst, i, dtrain, evals):
|
||||
break
|
||||
# do checkpoint after evaluation, in case evaluation also updates
|
||||
# booster.
|
||||
bst.save_rabit_checkpoint()
|
||||
version += 1
|
||||
|
||||
bst = callbacks.after_training(bst)
|
||||
|
||||
@@ -122,7 +92,12 @@ def _train_internal(params, dtrain,
|
||||
bst.best_score = float(bst.attr('best_score'))
|
||||
bst.best_iteration = int(bst.attr('best_iteration'))
|
||||
else:
|
||||
bst.best_iteration = nboost - 1
|
||||
bst.best_iteration = bst.num_boosted_rounds() - 1
|
||||
try:
|
||||
num_parallel_tree = int(json.loads(bst.save_config())['learner'][
|
||||
'gradient_booster']['gbtree_train_param']['num_parallel_tree'])
|
||||
except KeyError: # gblinear
|
||||
num_parallel_tree = 1
|
||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
||||
# Copy to serialise and unserialise booster to reset state and free
|
||||
# training memory
|
||||
@@ -234,7 +209,7 @@ class CVPack(object):
|
||||
|
||||
|
||||
class _PackedBooster:
|
||||
def __init__(self, cvfolds):
|
||||
def __init__(self, cvfolds) -> None:
|
||||
self.cvfolds = cvfolds
|
||||
|
||||
def update(self, iteration, obj):
|
||||
@@ -262,6 +237,10 @@ class _PackedBooster:
|
||||
ret = self.cvfolds[0].bst.attr('best_iteration')
|
||||
return int(ret)
|
||||
|
||||
def num_boosted_rounds(self) -> int:
|
||||
'''Number of boosted rounds.'''
|
||||
return self.cvfolds[0].bst.num_boosted_rounds()
|
||||
|
||||
|
||||
def groups_to_rows(groups, boundaries):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user