Support early stopping with training continuation, correct num boosted rounds. (#6506)
* Implement early stopping with training continuation. * Add new C API for obtaining boosted rounds. * Fix off by 1 in `save_best`. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
125b3c0f2d
commit
ca3da55de4
@ -614,6 +614,15 @@ XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
|
|||||||
int end_layer, int step,
|
int end_layer, int step,
|
||||||
BoosterHandle *out);
|
BoosterHandle *out);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Get number of boosted rounds from gradient booster. When process_type is
|
||||||
|
* update, this number might drop due to removed tree.
|
||||||
|
* \param handle Handle to booster.
|
||||||
|
* \param out Pointer to output integer.
|
||||||
|
* \return 0 when success, -1 when failure happens
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters
|
* \brief set parameters
|
||||||
* \param handle handle
|
* \param handle handle
|
||||||
|
|||||||
@ -79,6 +79,9 @@ class GradientBooster : public Model, public Configurable {
|
|||||||
virtual bool AllowLazyCheckPoint() const {
|
virtual bool AllowLazyCheckPoint() const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
/*! \brief Return number of boosted rounds.
|
||||||
|
*/
|
||||||
|
virtual int32_t BoostedRounds() const = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief perform update to the model(boosting)
|
* \brief perform update to the model(boosting)
|
||||||
* \param p_fmat feature matrix that provide access to features
|
* \param p_fmat feature matrix that provide access to features
|
||||||
|
|||||||
@ -134,6 +134,11 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
|
|||||||
HostDeviceVector<bst_float> **out_preds,
|
HostDeviceVector<bst_float> **out_preds,
|
||||||
uint32_t layer_begin, uint32_t layer_end) = 0;
|
uint32_t layer_begin, uint32_t layer_end) = 0;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* \brief Get number of boosted rounds from gradient booster.
|
||||||
|
*/
|
||||||
|
virtual int32_t BoostedRounds() const = 0;
|
||||||
|
|
||||||
void LoadModel(Json const& in) override = 0;
|
void LoadModel(Json const& in) override = 0;
|
||||||
void SaveModel(Json* out) const override = 0;
|
void SaveModel(Json* out) const override = 0;
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from abc import ABC
|
|||||||
import collections
|
import collections
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Callable, List
|
from typing import Callable, List, Optional, Union, Dict, Tuple
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from . import rabit
|
from . import rabit
|
||||||
@ -285,11 +285,13 @@ class TrainingCallback(ABC):
|
|||||||
'''Run after training is finished.'''
|
'''Run after training is finished.'''
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def before_iteration(self, model, epoch, evals_log):
|
def before_iteration(self, model, epoch: int,
|
||||||
|
evals_log: 'CallbackContainer.EvalsLog') -> bool:
|
||||||
'''Run before each iteration. Return True when training should stop.'''
|
'''Run before each iteration. Return True when training should stop.'''
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def after_iteration(self, model, epoch, evals_log):
|
def after_iteration(self, model, epoch: int,
|
||||||
|
evals_log: 'CallbackContainer.EvalsLog') -> bool:
|
||||||
'''Run after each iteration. Return True when training should stop.'''
|
'''Run after each iteration. Return True when training should stop.'''
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -346,8 +348,13 @@ class CallbackContainer:
|
|||||||
.. versionadded:: 1.3.0
|
.. versionadded:: 1.3.0
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(self, callbacks: List[TrainingCallback],
|
|
||||||
metric: Callable = None, is_cv: bool = False):
|
EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
callbacks: List[TrainingCallback],
|
||||||
|
metric: Callable = None,
|
||||||
|
is_cv: bool = False):
|
||||||
self.callbacks = set(callbacks)
|
self.callbacks = set(callbacks)
|
||||||
if metric is not None:
|
if metric is not None:
|
||||||
msg = 'metric must be callable object for monitoring. For ' + \
|
msg = 'metric must be callable object for monitoring. For ' + \
|
||||||
@ -355,7 +362,7 @@ class CallbackContainer:
|
|||||||
' will invoke monitor automatically.'
|
' will invoke monitor automatically.'
|
||||||
assert callable(metric), msg
|
assert callable(metric), msg
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
self.history = collections.OrderedDict()
|
self.history: CallbackContainer.EvalsLog = collections.OrderedDict()
|
||||||
self.is_cv = is_cv
|
self.is_cv = is_cv
|
||||||
|
|
||||||
if self.is_cv:
|
if self.is_cv:
|
||||||
@ -383,7 +390,7 @@ class CallbackContainer:
|
|||||||
assert isinstance(model, Booster), msg
|
assert isinstance(model, Booster), msg
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def before_iteration(self, model, epoch, dtrain, evals):
|
def before_iteration(self, model, epoch, dtrain, evals) -> bool:
|
||||||
'''Function called before training iteration.'''
|
'''Function called before training iteration.'''
|
||||||
return any(c.before_iteration(model, epoch, self.history)
|
return any(c.before_iteration(model, epoch, self.history)
|
||||||
for c in self.callbacks)
|
for c in self.callbacks)
|
||||||
@ -409,7 +416,7 @@ class CallbackContainer:
|
|||||||
self.history[data_name][metric_name] = [s]
|
self.history[data_name][metric_name] = [s]
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def after_iteration(self, model, epoch, dtrain, evals):
|
def after_iteration(self, model, epoch, dtrain, evals) -> bool:
|
||||||
'''Function called after training iteration.'''
|
'''Function called after training iteration.'''
|
||||||
if self.is_cv:
|
if self.is_cv:
|
||||||
scores = model.eval(epoch, self.metric)
|
scores = model.eval(epoch, self.metric)
|
||||||
@ -445,7 +452,7 @@ class LearningRateScheduler(TrainingCallback):
|
|||||||
rounds.
|
rounds.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(self, learning_rates):
|
def __init__(self, learning_rates) -> None:
|
||||||
assert callable(learning_rates) or \
|
assert callable(learning_rates) or \
|
||||||
isinstance(learning_rates, collections.abc.Sequence)
|
isinstance(learning_rates, collections.abc.Sequence)
|
||||||
if callable(learning_rates):
|
if callable(learning_rates):
|
||||||
@ -454,41 +461,42 @@ class LearningRateScheduler(TrainingCallback):
|
|||||||
self.learning_rates = lambda epoch: learning_rates[epoch]
|
self.learning_rates = lambda epoch: learning_rates[epoch]
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def after_iteration(self, model, epoch, evals_log):
|
def after_iteration(self, model, epoch, evals_log) -> bool:
|
||||||
model.set_param('learning_rate', self.learning_rates(epoch))
|
model.set_param('learning_rate', self.learning_rates(epoch))
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-instance-attributes
|
# pylint: disable=too-many-instance-attributes
|
||||||
class EarlyStopping(TrainingCallback):
|
class EarlyStopping(TrainingCallback):
|
||||||
''' Callback function for early stopping
|
"""Callback function for early stopping
|
||||||
|
|
||||||
.. versionadded:: 1.3.0
|
.. versionadded:: 1.3.0
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
rounds : int
|
rounds
|
||||||
Early stopping rounds.
|
Early stopping rounds.
|
||||||
metric_name : str
|
metric_name
|
||||||
Name of metric that is used for early stopping.
|
Name of metric that is used for early stopping.
|
||||||
data_name: str
|
data_name
|
||||||
Name of dataset that is used for early stopping.
|
Name of dataset that is used for early stopping.
|
||||||
maximize : bool
|
maximize
|
||||||
Whether to maximize evaluation metric. None means auto (discouraged).
|
Whether to maximize evaluation metric. None means auto (discouraged).
|
||||||
save_best : bool
|
save_best
|
||||||
Whether training should return the best model or the last model.
|
Whether training should return the best model or the last model.
|
||||||
'''
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
rounds,
|
rounds: int,
|
||||||
metric_name=None,
|
metric_name: Optional[str] = None,
|
||||||
data_name=None,
|
data_name: Optional[str] = None,
|
||||||
maximize=None,
|
maximize: Optional[bool] = None,
|
||||||
save_best=False):
|
save_best: Optional[bool] = False) -> None:
|
||||||
self.data = data_name
|
self.data = data_name
|
||||||
self.metric_name = metric_name
|
self.metric_name = metric_name
|
||||||
self.rounds = rounds
|
self.rounds = rounds
|
||||||
self.save_best = save_best
|
self.save_best = save_best
|
||||||
self.maximize = maximize
|
self.maximize = maximize
|
||||||
self.stopping_history = {}
|
self.stopping_history: CallbackContainer.EvalsLog = {}
|
||||||
|
|
||||||
if self.maximize is not None:
|
if self.maximize is not None:
|
||||||
if self.maximize:
|
if self.maximize:
|
||||||
@ -496,11 +504,16 @@ class EarlyStopping(TrainingCallback):
|
|||||||
else:
|
else:
|
||||||
self.improve_op = lambda x, y: x < y
|
self.improve_op = lambda x, y: x < y
|
||||||
|
|
||||||
self.current_rounds = 0
|
self.current_rounds: int = 0
|
||||||
self.best_scores = {}
|
self.best_scores: dict = {}
|
||||||
|
self.starting_round: int = 0
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def _update_rounds(self, score, name, metric, model, epoch):
|
def before_training(self, model):
|
||||||
|
self.starting_round = model.num_boosted_rounds()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _update_rounds(self, score, name, metric, model, epoch) -> bool:
|
||||||
# Just to be compatibility with old behavior before 1.3. We should let
|
# Just to be compatibility with old behavior before 1.3. We should let
|
||||||
# user to decide.
|
# user to decide.
|
||||||
if self.maximize is None:
|
if self.maximize is None:
|
||||||
@ -536,7 +549,9 @@ class EarlyStopping(TrainingCallback):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def after_iteration(self, model: Booster, epoch, evals_log):
|
def after_iteration(self, model, epoch: int,
|
||||||
|
evals_log: CallbackContainer.EvalsLog) -> bool:
|
||||||
|
epoch += self.starting_round # training continuation
|
||||||
msg = 'Must have at least 1 validation dataset for early stopping.'
|
msg = 'Must have at least 1 validation dataset for early stopping.'
|
||||||
assert len(evals_log.keys()) >= 1, msg
|
assert len(evals_log.keys()) >= 1, msg
|
||||||
data_name = ''
|
data_name = ''
|
||||||
@ -562,12 +577,14 @@ class EarlyStopping(TrainingCallback):
|
|||||||
score = data_log[metric_name][-1]
|
score = data_log[metric_name][-1]
|
||||||
return self._update_rounds(score, data_name, metric_name, model, epoch)
|
return self._update_rounds(score, data_name, metric_name, model, epoch)
|
||||||
|
|
||||||
def after_training(self, model: Booster):
|
def after_training(self, model):
|
||||||
try:
|
try:
|
||||||
if self.save_best:
|
if self.save_best:
|
||||||
model = model[: int(model.attr('best_iteration'))]
|
model = model[: int(model.attr("best_iteration")) + 1]
|
||||||
except XGBoostError as e:
|
except XGBoostError as e:
|
||||||
raise XGBoostError('`save_best` is not applicable to current booster') from e
|
raise XGBoostError(
|
||||||
|
"`save_best` is not applicable to current booster"
|
||||||
|
) from e
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -588,36 +605,37 @@ class EvaluationMonitor(TrainingCallback):
|
|||||||
show_stdv : bool
|
show_stdv : bool
|
||||||
Used in cv to show standard deviation. Users should not specify it.
|
Used in cv to show standard deviation. Users should not specify it.
|
||||||
'''
|
'''
|
||||||
def __init__(self, rank=0, period=1, show_stdv=False):
|
def __init__(self, rank=0, period=1, show_stdv=False) -> None:
|
||||||
self.printer_rank = rank
|
self.printer_rank = rank
|
||||||
self.show_stdv = show_stdv
|
self.show_stdv = show_stdv
|
||||||
self.period = period
|
self.period = period
|
||||||
assert period > 0
|
assert period > 0
|
||||||
# last error message, useful when early stopping and period are used together.
|
# last error message, useful when early stopping and period are used together.
|
||||||
self._latest = None
|
self._latest: Optional[str] = None
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def _fmt_metric(self, data, metric, score, std):
|
def _fmt_metric(self, data, metric, score, std) -> str:
|
||||||
if std is not None and self.show_stdv:
|
if std is not None and self.show_stdv:
|
||||||
msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
|
msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
|
||||||
else:
|
else:
|
||||||
msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
|
msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def after_iteration(self, model, epoch, evals_log):
|
def after_iteration(self, model, epoch: int,
|
||||||
|
evals_log: CallbackContainer.EvalsLog) -> bool:
|
||||||
if not evals_log:
|
if not evals_log:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
msg = f'[{epoch}]'
|
msg: str = f'[{epoch}]'
|
||||||
if rabit.get_rank() == self.printer_rank:
|
if rabit.get_rank() == self.printer_rank:
|
||||||
for data, metric in evals_log.items():
|
for data, metric in evals_log.items():
|
||||||
for metric_name, log in metric.items():
|
for metric_name, log in metric.items():
|
||||||
|
stdv: Optional[float] = None
|
||||||
if isinstance(log[-1], tuple):
|
if isinstance(log[-1], tuple):
|
||||||
score = log[-1][0]
|
score = log[-1][0]
|
||||||
stdv = log[-1][1]
|
stdv = log[-1][1]
|
||||||
else:
|
else:
|
||||||
score = log[-1]
|
score = log[-1]
|
||||||
stdv = None
|
|
||||||
msg += self._fmt_metric(data, metric_name, score, stdv)
|
msg += self._fmt_metric(data, metric_name, score, stdv)
|
||||||
msg += '\n'
|
msg += '\n'
|
||||||
|
|
||||||
@ -665,7 +683,8 @@ class TrainingCheckPoint(TrainingCallback):
|
|||||||
self._epoch = 0
|
self._epoch = 0
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def after_iteration(self, model, epoch, evals_log):
|
def after_iteration(self, model, epoch: int,
|
||||||
|
evals_log: CallbackContainer.EvalsLog) -> bool:
|
||||||
if self._epoch == self._iterations:
|
if self._epoch == self._iterations:
|
||||||
path = os.path.join(self._path, self._name + '_' + str(epoch) +
|
path = os.path.join(self._path, self._name + '_' + str(epoch) +
|
||||||
('.pkl' if self._as_pickle else '.json'))
|
('.pkl' if self._as_pickle else '.json'))
|
||||||
@ -677,6 +696,7 @@ class TrainingCheckPoint(TrainingCallback):
|
|||||||
else:
|
else:
|
||||||
model.save_model(path)
|
model.save_model(path)
|
||||||
self._epoch += 1
|
self._epoch += 1
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class LegacyCallbacks:
|
class LegacyCallbacks:
|
||||||
|
|||||||
@ -1177,23 +1177,6 @@ class Booster(object):
|
|||||||
"""
|
"""
|
||||||
return self.__copy__()
|
return self.__copy__()
|
||||||
|
|
||||||
def load_rabit_checkpoint(self):
|
|
||||||
"""Initialize the model by load from rabit checkpoint.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
version: integer
|
|
||||||
The version number of the model.
|
|
||||||
"""
|
|
||||||
version = ctypes.c_int()
|
|
||||||
_check_call(_LIB.XGBoosterLoadRabitCheckpoint(
|
|
||||||
self.handle, ctypes.byref(version)))
|
|
||||||
return version.value
|
|
||||||
|
|
||||||
def save_rabit_checkpoint(self):
|
|
||||||
"""Save the current booster to rabit checkpoint."""
|
|
||||||
_check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle))
|
|
||||||
|
|
||||||
def attr(self, key):
|
def attr(self, key):
|
||||||
"""Get attribute string from the Booster.
|
"""Get attribute string from the Booster.
|
||||||
|
|
||||||
@ -1745,6 +1728,17 @@ class Booster(object):
|
|||||||
else:
|
else:
|
||||||
raise TypeError('Unknown file type: ', fname)
|
raise TypeError('Unknown file type: ', fname)
|
||||||
|
|
||||||
|
def num_boosted_rounds(self) -> int:
|
||||||
|
'''Get number of boosted rounds. For gblinear this is reset to 0 after
|
||||||
|
serializing the model.
|
||||||
|
|
||||||
|
'''
|
||||||
|
rounds = ctypes.c_int()
|
||||||
|
assert self.handle is not None
|
||||||
|
_check_call(_LIB.XGBoosterBoostedRounds(
|
||||||
|
self.handle, ctypes.byref(rounds)))
|
||||||
|
return rounds.value
|
||||||
|
|
||||||
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
|
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
|
||||||
"""Dump model into a text or JSON file. Unlike `save_model`, the
|
"""Dump model into a text or JSON file. Unlike `save_model`, the
|
||||||
output format is primarily used for visualization or interpretation,
|
output format is primarily used for visualization or interpretation,
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import warnings
|
import warnings
|
||||||
import json
|
import json
|
||||||
|
from typing import Union, Optional, List, Dict, Callable, Tuple, Any
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
|
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
|
||||||
from .training import train
|
from .training import train
|
||||||
@ -494,6 +495,22 @@ class XGBModel(XGBModelBase):
|
|||||||
# Delete the attribute after load
|
# Delete the attribute after load
|
||||||
self.get_booster().set_attr(scikit_learn=None)
|
self.get_booster().set_attr(scikit_learn=None)
|
||||||
|
|
||||||
|
def _configure_fit(
|
||||||
|
self,
|
||||||
|
booster: Optional[Booster],
|
||||||
|
eval_metric: Optional[Union[Callable, str, List[str]]],
|
||||||
|
params: Dict[str, Any],
|
||||||
|
) -> Tuple[Booster, Optional[Union[Callable, str, List[str]]], Dict[str, Any]]:
|
||||||
|
model = self._Booster if hasattr(self, "_Booster") else None
|
||||||
|
model = booster if booster is not None else model
|
||||||
|
feval = eval_metric if callable(eval_metric) else None
|
||||||
|
if eval_metric is not None:
|
||||||
|
if callable(eval_metric):
|
||||||
|
eval_metric = None
|
||||||
|
else:
|
||||||
|
params.update({"eval_metric": eval_metric})
|
||||||
|
return model, feval, params
|
||||||
|
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit(self, X, y, *, sample_weight=None, base_margin=None,
|
def fit(self, X, y, *, sample_weight=None, base_margin=None,
|
||||||
eval_set=None, eval_metric=None, early_stopping_rounds=None,
|
eval_set=None, eval_metric=None, early_stopping_rounds=None,
|
||||||
@ -586,19 +603,13 @@ class XGBModel(XGBModelBase):
|
|||||||
else:
|
else:
|
||||||
obj = None
|
obj = None
|
||||||
|
|
||||||
feval = eval_metric if callable(eval_metric) else None
|
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
|
||||||
if eval_metric is not None:
|
|
||||||
if callable(eval_metric):
|
|
||||||
eval_metric = None
|
|
||||||
else:
|
|
||||||
params.update({'eval_metric': eval_metric})
|
|
||||||
|
|
||||||
self._Booster = train(params, train_dmatrix,
|
self._Booster = train(params, train_dmatrix,
|
||||||
self.get_num_boosting_rounds(), evals=evals,
|
self.get_num_boosting_rounds(), evals=evals,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals_result=evals_result,
|
evals_result=evals_result,
|
||||||
obj=obj, feval=feval,
|
obj=obj, feval=feval,
|
||||||
verbose_eval=verbose, xgb_model=xgb_model,
|
verbose_eval=verbose, xgb_model=model,
|
||||||
callbacks=callbacks)
|
callbacks=callbacks)
|
||||||
|
|
||||||
if evals_result:
|
if evals_result:
|
||||||
@ -857,27 +868,20 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
not np.array_equal(self.classes_, np.arange(self.n_classes_))):
|
not np.array_equal(self.classes_, np.arange(self.n_classes_))):
|
||||||
raise ValueError(label_encoding_check_error)
|
raise ValueError(label_encoding_check_error)
|
||||||
|
|
||||||
xgb_options = self.get_xgb_params()
|
params = self.get_xgb_params()
|
||||||
|
|
||||||
if callable(self.objective):
|
if callable(self.objective):
|
||||||
obj = _objective_decorator(self.objective)
|
obj = _objective_decorator(self.objective)
|
||||||
# Use default value. Is it really not used ?
|
# Use default value. Is it really not used ?
|
||||||
xgb_options["objective"] = "binary:logistic"
|
params["objective"] = "binary:logistic"
|
||||||
else:
|
else:
|
||||||
obj = None
|
obj = None
|
||||||
|
|
||||||
if self.n_classes_ > 2:
|
if self.n_classes_ > 2:
|
||||||
# Switch to using a multiclass objective in the underlying
|
# Switch to using a multiclass objective in the underlying
|
||||||
# XGB instance
|
# XGB instance
|
||||||
xgb_options['objective'] = 'multi:softprob'
|
params['objective'] = 'multi:softprob'
|
||||||
xgb_options['num_class'] = self.n_classes_
|
params['num_class'] = self.n_classes_
|
||||||
|
|
||||||
feval = eval_metric if callable(eval_metric) else None
|
|
||||||
if eval_metric is not None:
|
|
||||||
if callable(eval_metric):
|
|
||||||
eval_metric = None
|
|
||||||
else:
|
|
||||||
xgb_options.update({"eval_metric": eval_metric})
|
|
||||||
|
|
||||||
if self.use_label_encoder:
|
if self.use_label_encoder:
|
||||||
if not can_use_label_encoder:
|
if not can_use_label_encoder:
|
||||||
@ -891,6 +895,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
else:
|
else:
|
||||||
label_transform = (lambda x: x)
|
label_transform = (lambda x: x)
|
||||||
|
|
||||||
|
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
|
||||||
if len(X.shape) != 2:
|
if len(X.shape) != 2:
|
||||||
# Simply raise an error here since there might be many
|
# Simply raise an error here since there might be many
|
||||||
# different ways of reshaping
|
# different ways of reshaping
|
||||||
@ -906,15 +911,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
eval_set=eval_set, sample_weight_eval_set=sample_weight_eval_set,
|
eval_set=eval_set, sample_weight_eval_set=sample_weight_eval_set,
|
||||||
eval_group=None, label_transform=label_transform)
|
eval_group=None, label_transform=label_transform)
|
||||||
|
|
||||||
self._Booster = train(xgb_options, train_dmatrix,
|
self._Booster = train(params, train_dmatrix,
|
||||||
self.get_num_boosting_rounds(),
|
self.get_num_boosting_rounds(),
|
||||||
evals=evals,
|
evals=evals,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals_result=evals_result, obj=obj, feval=feval,
|
evals_result=evals_result, obj=obj, feval=feval,
|
||||||
verbose_eval=verbose, xgb_model=xgb_model,
|
verbose_eval=verbose, xgb_model=model,
|
||||||
callbacks=callbacks)
|
callbacks=callbacks)
|
||||||
|
|
||||||
self.objective = xgb_options["objective"]
|
self.objective = params["objective"]
|
||||||
if evals_result:
|
if evals_result:
|
||||||
for val in evals_result.items():
|
for val in evals_result.items():
|
||||||
evals_result_key = list(val[1].keys())[0]
|
evals_result_key = list(val[1].keys())[0]
|
||||||
|
|||||||
@ -4,11 +4,10 @@
|
|||||||
"""Training Library containing training routines."""
|
"""Training Library containing training routines."""
|
||||||
import warnings
|
import warnings
|
||||||
import copy
|
import copy
|
||||||
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster, XGBoostError
|
from .core import Booster, XGBoostError
|
||||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
||||||
from . import rabit
|
|
||||||
from . import callback
|
from . import callback
|
||||||
|
|
||||||
|
|
||||||
@ -51,28 +50,12 @@ def _train_internal(params, dtrain,
|
|||||||
evals = list(evals)
|
evals = list(evals)
|
||||||
|
|
||||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||||
nboost = 0
|
|
||||||
num_parallel_tree = 1
|
|
||||||
|
|
||||||
if xgb_model is not None:
|
if xgb_model is not None:
|
||||||
bst = Booster(params, [dtrain] + [d[0] for d in evals],
|
bst = Booster(params, [dtrain] + [d[0] for d in evals],
|
||||||
model_file=xgb_model)
|
model_file=xgb_model)
|
||||||
nboost = len(bst.get_dump())
|
|
||||||
|
|
||||||
_params = dict(params) if isinstance(params, list) else params
|
start_iteration = 0
|
||||||
|
|
||||||
if 'num_parallel_tree' in _params and _params[
|
|
||||||
'num_parallel_tree'] is not None:
|
|
||||||
num_parallel_tree = _params['num_parallel_tree']
|
|
||||||
nboost //= num_parallel_tree
|
|
||||||
if 'num_class' in _params and _params['num_class'] is not None:
|
|
||||||
nboost //= _params['num_class']
|
|
||||||
|
|
||||||
# Distributed code: Load the checkpoint from rabit.
|
|
||||||
version = bst.load_rabit_checkpoint()
|
|
||||||
assert rabit.get_world_size() != 1 or version == 0
|
|
||||||
start_iteration = int(version / 2)
|
|
||||||
nboost += start_iteration
|
|
||||||
|
|
||||||
is_new_callback = _is_new_callback(callbacks)
|
is_new_callback = _is_new_callback(callbacks)
|
||||||
if is_new_callback:
|
if is_new_callback:
|
||||||
@ -92,26 +75,13 @@ def _train_internal(params, dtrain,
|
|||||||
show_stdv=False, cvfolds=None)
|
show_stdv=False, cvfolds=None)
|
||||||
|
|
||||||
bst = callbacks.before_training(bst)
|
bst = callbacks.before_training(bst)
|
||||||
|
|
||||||
for i in range(start_iteration, num_boost_round):
|
for i in range(start_iteration, num_boost_round):
|
||||||
if callbacks.before_iteration(bst, i, dtrain, evals):
|
if callbacks.before_iteration(bst, i, dtrain, evals):
|
||||||
break
|
break
|
||||||
# Distributed code: need to resume to this point.
|
bst.update(dtrain, i, obj)
|
||||||
# Skip the first update if it is a recovery step.
|
|
||||||
if version % 2 == 0:
|
|
||||||
bst.update(dtrain, i, obj)
|
|
||||||
bst.save_rabit_checkpoint()
|
|
||||||
version += 1
|
|
||||||
|
|
||||||
assert rabit.get_world_size() == 1 or version == rabit.version_number()
|
|
||||||
|
|
||||||
nboost += 1
|
|
||||||
# check evaluation result.
|
|
||||||
if callbacks.after_iteration(bst, i, dtrain, evals):
|
if callbacks.after_iteration(bst, i, dtrain, evals):
|
||||||
break
|
break
|
||||||
# do checkpoint after evaluation, in case evaluation also updates
|
|
||||||
# booster.
|
|
||||||
bst.save_rabit_checkpoint()
|
|
||||||
version += 1
|
|
||||||
|
|
||||||
bst = callbacks.after_training(bst)
|
bst = callbacks.after_training(bst)
|
||||||
|
|
||||||
@ -122,7 +92,12 @@ def _train_internal(params, dtrain,
|
|||||||
bst.best_score = float(bst.attr('best_score'))
|
bst.best_score = float(bst.attr('best_score'))
|
||||||
bst.best_iteration = int(bst.attr('best_iteration'))
|
bst.best_iteration = int(bst.attr('best_iteration'))
|
||||||
else:
|
else:
|
||||||
bst.best_iteration = nboost - 1
|
bst.best_iteration = bst.num_boosted_rounds() - 1
|
||||||
|
try:
|
||||||
|
num_parallel_tree = int(json.loads(bst.save_config())['learner'][
|
||||||
|
'gradient_booster']['gbtree_train_param']['num_parallel_tree'])
|
||||||
|
except KeyError: # gblinear
|
||||||
|
num_parallel_tree = 1
|
||||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
||||||
# Copy to serialise and unserialise booster to reset state and free
|
# Copy to serialise and unserialise booster to reset state and free
|
||||||
# training memory
|
# training memory
|
||||||
@ -234,7 +209,7 @@ class CVPack(object):
|
|||||||
|
|
||||||
|
|
||||||
class _PackedBooster:
|
class _PackedBooster:
|
||||||
def __init__(self, cvfolds):
|
def __init__(self, cvfolds) -> None:
|
||||||
self.cvfolds = cvfolds
|
self.cvfolds = cvfolds
|
||||||
|
|
||||||
def update(self, iteration, obj):
|
def update(self, iteration, obj):
|
||||||
@ -262,6 +237,10 @@ class _PackedBooster:
|
|||||||
ret = self.cvfolds[0].bst.attr('best_iteration')
|
ret = self.cvfolds[0].bst.attr('best_iteration')
|
||||||
return int(ret)
|
return int(ret)
|
||||||
|
|
||||||
|
def num_boosted_rounds(self) -> int:
|
||||||
|
'''Number of boosted rounds.'''
|
||||||
|
return self.cvfolds[0].bst.num_boosted_rounds()
|
||||||
|
|
||||||
|
|
||||||
def groups_to_rows(groups, boundaries):
|
def groups_to_rows(groups, boundaries):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -502,6 +502,14 @@ XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out) {
|
||||||
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
|
static_cast<Learner*>(handle)->Configure();
|
||||||
|
*out = static_cast<Learner*>(handle)->BoostedRounds();
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
|
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
|||||||
@ -73,6 +73,10 @@ class GBLinear : public GradientBooster {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int32_t BoostedRounds() const override {
|
||||||
|
return model_.num_boosted_rounds;
|
||||||
|
}
|
||||||
|
|
||||||
void Load(dmlc::Stream* fi) override {
|
void Load(dmlc::Stream* fi) override {
|
||||||
model_.Load(fi);
|
model_.Load(fi);
|
||||||
}
|
}
|
||||||
@ -122,7 +126,7 @@ class GBLinear : public GradientBooster {
|
|||||||
if (!this->CheckConvergence()) {
|
if (!this->CheckConvergence()) {
|
||||||
updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_);
|
updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_);
|
||||||
}
|
}
|
||||||
|
model_.num_boosted_rounds++;
|
||||||
monitor_.Stop("DoBoost");
|
monitor_.Stop("DoBoost");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -44,11 +44,12 @@ class GBLinearModel : public Model {
|
|||||||
DeprecatedGBLinearModelParam param_;
|
DeprecatedGBLinearModelParam param_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
int32_t num_boosted_rounds;
|
||||||
LearnerModelParam const* learner_model_param;
|
LearnerModelParam const* learner_model_param;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit GBLinearModel(LearnerModelParam const* learner_model_param) :
|
explicit GBLinearModel(LearnerModelParam const* learner_model_param) :
|
||||||
learner_model_param {learner_model_param} {}
|
num_boosted_rounds{0}, learner_model_param {learner_model_param} {}
|
||||||
void Configure(Args const &) { }
|
void Configure(Args const &) { }
|
||||||
|
|
||||||
// weight for each of feature, bias is the last one
|
// weight for each of feature, bias is the last one
|
||||||
|
|||||||
@ -249,10 +249,17 @@ class GBTree : public GradientBooster {
|
|||||||
auto n_trees = model_.learner_model_param->num_output_group * tparam_.num_parallel_tree;
|
auto n_trees = model_.learner_model_param->num_output_group * tparam_.num_parallel_tree;
|
||||||
return n_trees;
|
return n_trees;
|
||||||
}
|
}
|
||||||
|
|
||||||
// slice the trees, out must be already allocated
|
// slice the trees, out must be already allocated
|
||||||
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
|
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
|
||||||
GradientBooster *out, bool* out_of_bound) const override;
|
GradientBooster *out, bool* out_of_bound) const override;
|
||||||
|
|
||||||
|
int32_t BoostedRounds() const override {
|
||||||
|
CHECK_NE(tparam_.num_parallel_tree, 0);
|
||||||
|
CHECK_NE(model_.learner_model_param->num_output_group, 0);
|
||||||
|
return model_.trees.size() / this->LayerTrees();
|
||||||
|
}
|
||||||
|
|
||||||
void PredictBatch(DMatrix* p_fmat,
|
void PredictBatch(DMatrix* p_fmat,
|
||||||
PredictionCacheEntry* out_preds,
|
PredictionCacheEntry* out_preds,
|
||||||
bool training,
|
bool training,
|
||||||
|
|||||||
@ -1107,6 +1107,12 @@ class LearnerImpl : public LearnerIO {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int32_t BoostedRounds() const override {
|
||||||
|
if (!this->gbm_) { return 0; } // haven't called train or LoadModel.
|
||||||
|
CHECK(!this->need_configuration_);
|
||||||
|
return this->gbm_->BoostedRounds();
|
||||||
|
}
|
||||||
|
|
||||||
XGBAPIThreadLocalEntry& GetThreadLocal() const override {
|
XGBAPIThreadLocalEntry& GetThreadLocal() const override {
|
||||||
return (*LearnerAPIThreadLocalStore::Get())[this];
|
return (*LearnerAPIThreadLocalStore::Get())[this];
|
||||||
}
|
}
|
||||||
|
|||||||
@ -124,6 +124,20 @@ class TestModels:
|
|||||||
predt_2 = bst.predict(dtrain)
|
predt_2 = bst.predict(dtrain)
|
||||||
assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
|
assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
|
||||||
|
|
||||||
|
def test_boost_from_existing_model(self):
|
||||||
|
X = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
|
booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4)
|
||||||
|
assert booster.num_boosted_rounds() == 4
|
||||||
|
booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4,
|
||||||
|
xgb_model=booster)
|
||||||
|
assert booster.num_boosted_rounds() == 8
|
||||||
|
booster = xgb.train({'updater': 'prune', 'process_type': 'update'}, X,
|
||||||
|
num_boost_round=4, xgb_model=booster)
|
||||||
|
# Trees are moved for update, so the number of rounds is reduced. This test is
|
||||||
|
# written for being compatible with current code (1.0.0). If the
|
||||||
|
# behaviour is considered sub-optimal, feel free to change.
|
||||||
|
assert booster.num_boosted_rounds() == 4
|
||||||
|
|
||||||
def test_custom_objective(self):
|
def test_custom_objective(self):
|
||||||
param = {'max_depth': 2, 'eta': 1, 'objective': 'reg:logistic'}
|
param = {'max_depth': 2, 'eta': 1, 'objective': 'reg:logistic'}
|
||||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||||
|
|||||||
@ -43,7 +43,7 @@ class TestCallbacks:
|
|||||||
# Should print info by each period additionally to first and latest iteration
|
# Should print info by each period additionally to first and latest iteration
|
||||||
num_periods = rounds // int(verbose_eval)
|
num_periods = rounds // int(verbose_eval)
|
||||||
# Extra information is required for latest iteration
|
# Extra information is required for latest iteration
|
||||||
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
|
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
|
||||||
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required)
|
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required)
|
||||||
|
|
||||||
def test_evaluation_monitor(self):
|
def test_evaluation_monitor(self):
|
||||||
@ -63,7 +63,7 @@ class TestCallbacks:
|
|||||||
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
|
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
|
||||||
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
|
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
|
||||||
self.run_evaluation_monitor(D_train, D_valid, rounds, 4)
|
self.run_evaluation_monitor(D_train, D_valid, rounds, 4)
|
||||||
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)
|
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)
|
||||||
|
|
||||||
def test_early_stopping(self):
|
def test_early_stopping(self):
|
||||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||||
@ -81,6 +81,15 @@ class TestCallbacks:
|
|||||||
dump = booster.get_dump(dump_format='json')
|
dump = booster.get_dump(dump_format='json')
|
||||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||||
|
|
||||||
|
# No early stopping, best_iteration should be set to last epoch
|
||||||
|
booster = xgb.train({'objective': 'binary:logistic',
|
||||||
|
'eval_metric': 'error'}, D_train,
|
||||||
|
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
||||||
|
num_boost_round=10,
|
||||||
|
evals_result=evals_result,
|
||||||
|
verbose_eval=True)
|
||||||
|
assert booster.num_boosted_rounds() - 1 == booster.best_iteration
|
||||||
|
|
||||||
def test_early_stopping_custom_eval(self):
|
def test_early_stopping_custom_eval(self):
|
||||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||||
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
|
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
|
||||||
@ -153,7 +162,7 @@ class TestCallbacks:
|
|||||||
eval_metric=tm.eval_error_metric, callbacks=[early_stop])
|
eval_metric=tm.eval_error_metric, callbacks=[early_stop])
|
||||||
booster = cls.get_booster()
|
booster = cls.get_booster()
|
||||||
dump = booster.get_dump(dump_format='json')
|
dump = booster.get_dump(dump_format='json')
|
||||||
assert len(dump) == booster.best_iteration
|
assert len(dump) == booster.best_iteration + 1
|
||||||
|
|
||||||
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||||
save_best=True)
|
save_best=True)
|
||||||
@ -170,6 +179,32 @@ class TestCallbacks:
|
|||||||
eval_metric=tm.eval_error_metric,
|
eval_metric=tm.eval_error_metric,
|
||||||
callbacks=[early_stop])
|
callbacks=[early_stop])
|
||||||
|
|
||||||
|
def test_early_stopping_continuation(self):
|
||||||
|
from sklearn.datasets import load_breast_cancer
|
||||||
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
|
cls = xgb.XGBClassifier()
|
||||||
|
early_stopping_rounds = 5
|
||||||
|
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||||
|
save_best=True)
|
||||||
|
cls.fit(X, y, eval_set=[(X, y)],
|
||||||
|
eval_metric=tm.eval_error_metric,
|
||||||
|
callbacks=[early_stop])
|
||||||
|
booster = cls.get_booster()
|
||||||
|
assert booster.num_boosted_rounds() == booster.best_iteration + 1
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
path = os.path.join(tmpdir, 'model.json')
|
||||||
|
cls.save_model(path)
|
||||||
|
cls = xgb.XGBClassifier()
|
||||||
|
cls.load_model(path)
|
||||||
|
assert cls._Booster is not None
|
||||||
|
early_stopping_rounds = 3
|
||||||
|
cls.fit(X, y, eval_set=[(X, y)], eval_metric=tm.eval_error_metric,
|
||||||
|
early_stopping_rounds=early_stopping_rounds)
|
||||||
|
booster = cls.get_booster()
|
||||||
|
assert booster.num_boosted_rounds() == \
|
||||||
|
booster.best_iteration + early_stopping_rounds + 1
|
||||||
|
|
||||||
def run_eta_decay(self, tree_method, deprecated_callback):
|
def run_eta_decay(self, tree_method, deprecated_callback):
|
||||||
if deprecated_callback:
|
if deprecated_callback:
|
||||||
scheduler = xgb.callback.reset_learning_rate
|
scheduler = xgb.callback.reset_learning_rate
|
||||||
|
|||||||
@ -46,7 +46,7 @@ class TestEarlyStopping:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def assert_metrics_length(cv, expected_length):
|
def assert_metrics_length(cv, expected_length):
|
||||||
for key, value in cv.items():
|
for key, value in cv.items():
|
||||||
assert len(value) == expected_length
|
assert len(value) == expected_length
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_cv_early_stopping(self):
|
def test_cv_early_stopping(self):
|
||||||
|
|||||||
@ -62,6 +62,8 @@ def test_multiclass_classification():
|
|||||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||||
for train_index, test_index in kf.split(X, y):
|
for train_index, test_index in kf.split(X, y):
|
||||||
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
|
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
|
||||||
|
assert (xgb_model.get_booster().num_boosted_rounds() ==
|
||||||
|
xgb_model.n_estimators)
|
||||||
preds = xgb_model.predict(X[test_index])
|
preds = xgb_model.predict(X[test_index])
|
||||||
# test other params in XGBClassifier().fit
|
# test other params in XGBClassifier().fit
|
||||||
preds2 = xgb_model.predict(X[test_index], output_margin=True,
|
preds2 = xgb_model.predict(X[test_index], output_margin=True,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user