Support early stopping with training continuation, correct num boosted rounds. (#6506)

* Implement early stopping with training continuation.

* Add new C API for obtaining boosted rounds.

* Fix off by 1 in `save_best`.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2020-12-17 19:59:19 +08:00 committed by GitHub
parent 125b3c0f2d
commit ca3da55de4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 210 additions and 118 deletions

View File

@ -614,6 +614,15 @@ XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
int end_layer, int step, int end_layer, int step,
BoosterHandle *out); BoosterHandle *out);
/*!
* \brief Get number of boosted rounds from gradient booster. When process_type is
* update, this number might drop because trees are removed.
* \param handle Handle to booster.
* \param out Pointer to output integer.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out);
/*! /*!
* \brief set parameters * \brief set parameters
* \param handle handle * \param handle handle

View File

@ -79,6 +79,9 @@ class GradientBooster : public Model, public Configurable {
virtual bool AllowLazyCheckPoint() const { virtual bool AllowLazyCheckPoint() const {
return false; return false;
} }
/*! \brief Return number of boosted rounds.
*/
virtual int32_t BoostedRounds() const = 0;
/*! /*!
* \brief perform update to the model(boosting) * \brief perform update to the model(boosting)
* \param p_fmat feature matrix that provide access to features * \param p_fmat feature matrix that provide access to features

View File

@ -134,6 +134,11 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
HostDeviceVector<bst_float> **out_preds, HostDeviceVector<bst_float> **out_preds,
uint32_t layer_begin, uint32_t layer_end) = 0; uint32_t layer_begin, uint32_t layer_end) = 0;
/*!
* \brief Get number of boosted rounds from gradient booster.
*/
virtual int32_t BoostedRounds() const = 0;
void LoadModel(Json const& in) override = 0; void LoadModel(Json const& in) override = 0;
void SaveModel(Json* out) const override = 0; void SaveModel(Json* out) const override = 0;

View File

@ -6,7 +6,7 @@ from abc import ABC
import collections import collections
import os import os
import pickle import pickle
from typing import Callable, List from typing import Callable, List, Optional, Union, Dict, Tuple
import numpy import numpy
from . import rabit from . import rabit
@ -285,11 +285,13 @@ class TrainingCallback(ABC):
'''Run after training is finished.''' '''Run after training is finished.'''
return model return model
def before_iteration(self, model, epoch, evals_log): def before_iteration(self, model, epoch: int,
evals_log: 'CallbackContainer.EvalsLog') -> bool:
'''Run before each iteration. Return True when training should stop.''' '''Run before each iteration. Return True when training should stop.'''
return False return False
def after_iteration(self, model, epoch, evals_log): def after_iteration(self, model, epoch: int,
evals_log: 'CallbackContainer.EvalsLog') -> bool:
'''Run after each iteration. Return True when training should stop.''' '''Run after each iteration. Return True when training should stop.'''
return False return False
@ -346,8 +348,13 @@ class CallbackContainer:
.. versionadded:: 1.3.0 .. versionadded:: 1.3.0
''' '''
def __init__(self, callbacks: List[TrainingCallback],
metric: Callable = None, is_cv: bool = False): EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
def __init__(self,
callbacks: List[TrainingCallback],
metric: Callable = None,
is_cv: bool = False):
self.callbacks = set(callbacks) self.callbacks = set(callbacks)
if metric is not None: if metric is not None:
msg = 'metric must be callable object for monitoring. For ' + \ msg = 'metric must be callable object for monitoring. For ' + \
@ -355,7 +362,7 @@ class CallbackContainer:
' will invoke monitor automatically.' ' will invoke monitor automatically.'
assert callable(metric), msg assert callable(metric), msg
self.metric = metric self.metric = metric
self.history = collections.OrderedDict() self.history: CallbackContainer.EvalsLog = collections.OrderedDict()
self.is_cv = is_cv self.is_cv = is_cv
if self.is_cv: if self.is_cv:
@ -383,7 +390,7 @@ class CallbackContainer:
assert isinstance(model, Booster), msg assert isinstance(model, Booster), msg
return model return model
def before_iteration(self, model, epoch, dtrain, evals): def before_iteration(self, model, epoch, dtrain, evals) -> bool:
'''Function called before training iteration.''' '''Function called before training iteration.'''
return any(c.before_iteration(model, epoch, self.history) return any(c.before_iteration(model, epoch, self.history)
for c in self.callbacks) for c in self.callbacks)
@ -409,7 +416,7 @@ class CallbackContainer:
self.history[data_name][metric_name] = [s] self.history[data_name][metric_name] = [s]
return False return False
def after_iteration(self, model, epoch, dtrain, evals): def after_iteration(self, model, epoch, dtrain, evals) -> bool:
'''Function called after training iteration.''' '''Function called after training iteration.'''
if self.is_cv: if self.is_cv:
scores = model.eval(epoch, self.metric) scores = model.eval(epoch, self.metric)
@ -445,7 +452,7 @@ class LearningRateScheduler(TrainingCallback):
rounds. rounds.
''' '''
def __init__(self, learning_rates): def __init__(self, learning_rates) -> None:
assert callable(learning_rates) or \ assert callable(learning_rates) or \
isinstance(learning_rates, collections.abc.Sequence) isinstance(learning_rates, collections.abc.Sequence)
if callable(learning_rates): if callable(learning_rates):
@ -454,41 +461,42 @@ class LearningRateScheduler(TrainingCallback):
self.learning_rates = lambda epoch: learning_rates[epoch] self.learning_rates = lambda epoch: learning_rates[epoch]
super().__init__() super().__init__()
def after_iteration(self, model, epoch, evals_log): def after_iteration(self, model, epoch, evals_log) -> bool:
model.set_param('learning_rate', self.learning_rates(epoch)) model.set_param('learning_rate', self.learning_rates(epoch))
return False
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
class EarlyStopping(TrainingCallback): class EarlyStopping(TrainingCallback):
''' Callback function for early stopping """Callback function for early stopping
.. versionadded:: 1.3.0 .. versionadded:: 1.3.0
Parameters Parameters
---------- ----------
rounds : int rounds
Early stopping rounds. Early stopping rounds.
metric_name : str metric_name
Name of metric that is used for early stopping. Name of metric that is used for early stopping.
data_name: str data_name
Name of dataset that is used for early stopping. Name of dataset that is used for early stopping.
maximize : bool maximize
Whether to maximize evaluation metric. None means auto (discouraged). Whether to maximize evaluation metric. None means auto (discouraged).
save_best : bool save_best
Whether training should return the best model or the last model. Whether training should return the best model or the last model.
''' """
def __init__(self, def __init__(self,
rounds, rounds: int,
metric_name=None, metric_name: Optional[str] = None,
data_name=None, data_name: Optional[str] = None,
maximize=None, maximize: Optional[bool] = None,
save_best=False): save_best: Optional[bool] = False) -> None:
self.data = data_name self.data = data_name
self.metric_name = metric_name self.metric_name = metric_name
self.rounds = rounds self.rounds = rounds
self.save_best = save_best self.save_best = save_best
self.maximize = maximize self.maximize = maximize
self.stopping_history = {} self.stopping_history: CallbackContainer.EvalsLog = {}
if self.maximize is not None: if self.maximize is not None:
if self.maximize: if self.maximize:
@ -496,11 +504,16 @@ class EarlyStopping(TrainingCallback):
else: else:
self.improve_op = lambda x, y: x < y self.improve_op = lambda x, y: x < y
self.current_rounds = 0 self.current_rounds: int = 0
self.best_scores = {} self.best_scores: dict = {}
self.starting_round: int = 0
super().__init__() super().__init__()
def _update_rounds(self, score, name, metric, model, epoch): def before_training(self, model):
self.starting_round = model.num_boosted_rounds()
return model
def _update_rounds(self, score, name, metric, model, epoch) -> bool:
# Just to be compatibility with old behavior before 1.3. We should let # Just to be compatibility with old behavior before 1.3. We should let
# user to decide. # user to decide.
if self.maximize is None: if self.maximize is None:
@ -536,7 +549,9 @@ class EarlyStopping(TrainingCallback):
return True return True
return False return False
def after_iteration(self, model: Booster, epoch, evals_log): def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
epoch += self.starting_round # training continuation
msg = 'Must have at least 1 validation dataset for early stopping.' msg = 'Must have at least 1 validation dataset for early stopping.'
assert len(evals_log.keys()) >= 1, msg assert len(evals_log.keys()) >= 1, msg
data_name = '' data_name = ''
@ -562,12 +577,14 @@ class EarlyStopping(TrainingCallback):
score = data_log[metric_name][-1] score = data_log[metric_name][-1]
return self._update_rounds(score, data_name, metric_name, model, epoch) return self._update_rounds(score, data_name, metric_name, model, epoch)
def after_training(self, model: Booster): def after_training(self, model):
try: try:
if self.save_best: if self.save_best:
model = model[: int(model.attr('best_iteration'))] model = model[: int(model.attr("best_iteration")) + 1]
except XGBoostError as e: except XGBoostError as e:
raise XGBoostError('`save_best` is not applicable to current booster') from e raise XGBoostError(
"`save_best` is not applicable to current booster"
) from e
return model return model
@ -588,36 +605,37 @@ class EvaluationMonitor(TrainingCallback):
show_stdv : bool show_stdv : bool
Used in cv to show standard deviation. Users should not specify it. Used in cv to show standard deviation. Users should not specify it.
''' '''
def __init__(self, rank=0, period=1, show_stdv=False): def __init__(self, rank=0, period=1, show_stdv=False) -> None:
self.printer_rank = rank self.printer_rank = rank
self.show_stdv = show_stdv self.show_stdv = show_stdv
self.period = period self.period = period
assert period > 0 assert period > 0
# last error message, useful when early stopping and period are used together. # last error message, useful when early stopping and period are used together.
self._latest = None self._latest: Optional[str] = None
super().__init__() super().__init__()
def _fmt_metric(self, data, metric, score, std): def _fmt_metric(self, data, metric, score, std) -> str:
if std is not None and self.show_stdv: if std is not None and self.show_stdv:
msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std) msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
else: else:
msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score) msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
return msg return msg
def after_iteration(self, model, epoch, evals_log): def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
if not evals_log: if not evals_log:
return False return False
msg = f'[{epoch}]' msg: str = f'[{epoch}]'
if rabit.get_rank() == self.printer_rank: if rabit.get_rank() == self.printer_rank:
for data, metric in evals_log.items(): for data, metric in evals_log.items():
for metric_name, log in metric.items(): for metric_name, log in metric.items():
stdv: Optional[float] = None
if isinstance(log[-1], tuple): if isinstance(log[-1], tuple):
score = log[-1][0] score = log[-1][0]
stdv = log[-1][1] stdv = log[-1][1]
else: else:
score = log[-1] score = log[-1]
stdv = None
msg += self._fmt_metric(data, metric_name, score, stdv) msg += self._fmt_metric(data, metric_name, score, stdv)
msg += '\n' msg += '\n'
@ -665,7 +683,8 @@ class TrainingCheckPoint(TrainingCallback):
self._epoch = 0 self._epoch = 0
super().__init__() super().__init__()
def after_iteration(self, model, epoch, evals_log): def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
if self._epoch == self._iterations: if self._epoch == self._iterations:
path = os.path.join(self._path, self._name + '_' + str(epoch) + path = os.path.join(self._path, self._name + '_' + str(epoch) +
('.pkl' if self._as_pickle else '.json')) ('.pkl' if self._as_pickle else '.json'))
@ -677,6 +696,7 @@ class TrainingCheckPoint(TrainingCallback):
else: else:
model.save_model(path) model.save_model(path)
self._epoch += 1 self._epoch += 1
return False
class LegacyCallbacks: class LegacyCallbacks:

View File

@ -1177,23 +1177,6 @@ class Booster(object):
""" """
return self.__copy__() return self.__copy__()
def load_rabit_checkpoint(self):
"""Initialize the model by load from rabit checkpoint.
Returns
-------
version: integer
The version number of the model.
"""
version = ctypes.c_int()
_check_call(_LIB.XGBoosterLoadRabitCheckpoint(
self.handle, ctypes.byref(version)))
return version.value
def save_rabit_checkpoint(self):
"""Save the current booster to rabit checkpoint."""
_check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle))
def attr(self, key): def attr(self, key):
"""Get attribute string from the Booster. """Get attribute string from the Booster.
@ -1745,6 +1728,17 @@ class Booster(object):
else: else:
raise TypeError('Unknown file type: ', fname) raise TypeError('Unknown file type: ', fname)
def num_boosted_rounds(self) -> int:
    """Return the number of boosted rounds reported by the native library.

    For gblinear this is reset to 0 after serializing the model.
    """
    assert self.handle is not None
    # Output slot that the C API fills in through a pointer.
    n_rounds = ctypes.c_int()
    _check_call(
        _LIB.XGBoosterBoostedRounds(self.handle, ctypes.byref(n_rounds))
    )
    return n_rounds.value
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"): def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
"""Dump model into a text or JSON file. Unlike `save_model`, the """Dump model into a text or JSON file. Unlike `save_model`, the
output format is primarily used for visualization or interpretation, output format is primarily used for visualization or interpretation,

View File

@ -4,6 +4,7 @@
import copy import copy
import warnings import warnings
import json import json
from typing import Union, Optional, List, Dict, Callable, Tuple, Any
import numpy as np import numpy as np
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
from .training import train from .training import train
@ -494,6 +495,22 @@ class XGBModel(XGBModelBase):
# Delete the attribute after load # Delete the attribute after load
self.get_booster().set_attr(scikit_learn=None) self.get_booster().set_attr(scikit_learn=None)
def _configure_fit(
    self,
    booster: Optional[Booster],
    eval_metric: Optional[Union[Callable, str, List[str]]],
    params: Dict[str, Any],
) -> Tuple[Optional[Booster], Optional[Callable], Dict[str, Any]]:
    """Resolve the starting model and the evaluation metric for ``fit``.

    Parameters
    ----------
    booster
        Model explicitly supplied by the caller to continue training from.
        When it is None, the booster already stored on this estimator
        (``self._Booster``) is used if present, so fitting an estimator that
        already holds a booster continues from that booster.
    eval_metric
        Either a callable metric, a metric name, or a list of metric names.
    params
        Training parameters.  Updated in place with ``eval_metric`` when the
        metric is given by name.

    Returns
    -------
    model, feval, params
        ``model`` is the booster to continue from, or None when there is
        none.  ``feval`` is ``eval_metric`` when it is callable, otherwise
        None.  ``params`` is the same dictionary object that was passed in.
    """
    model = booster
    if model is None:
        # No explicit model: fall back to the booster held by the estimator.
        model = getattr(self, "_Booster", None)

    if callable(eval_metric):
        # A callable metric is handed back as ``feval`` and is deliberately
        # not written into ``params``.
        return model, eval_metric, params

    if eval_metric is not None:
        params.update({"eval_metric": eval_metric})
    return model, None, params
@_deprecate_positional_args @_deprecate_positional_args
def fit(self, X, y, *, sample_weight=None, base_margin=None, def fit(self, X, y, *, sample_weight=None, base_margin=None,
eval_set=None, eval_metric=None, early_stopping_rounds=None, eval_set=None, eval_metric=None, early_stopping_rounds=None,
@ -586,19 +603,13 @@ class XGBModel(XGBModelBase):
else: else:
obj = None obj = None
feval = eval_metric if callable(eval_metric) else None model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({'eval_metric': eval_metric})
self._Booster = train(params, train_dmatrix, self._Booster = train(params, train_dmatrix,
self.get_num_boosting_rounds(), evals=evals, self.get_num_boosting_rounds(), evals=evals,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, evals_result=evals_result,
obj=obj, feval=feval, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model, verbose_eval=verbose, xgb_model=model,
callbacks=callbacks) callbacks=callbacks)
if evals_result: if evals_result:
@ -857,27 +868,20 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
not np.array_equal(self.classes_, np.arange(self.n_classes_))): not np.array_equal(self.classes_, np.arange(self.n_classes_))):
raise ValueError(label_encoding_check_error) raise ValueError(label_encoding_check_error)
xgb_options = self.get_xgb_params() params = self.get_xgb_params()
if callable(self.objective): if callable(self.objective):
obj = _objective_decorator(self.objective) obj = _objective_decorator(self.objective)
# Use default value. Is it really not used ? # Use default value. Is it really not used ?
xgb_options["objective"] = "binary:logistic" params["objective"] = "binary:logistic"
else: else:
obj = None obj = None
if self.n_classes_ > 2: if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying # Switch to using a multiclass objective in the underlying
# XGB instance # XGB instance
xgb_options['objective'] = 'multi:softprob' params['objective'] = 'multi:softprob'
xgb_options['num_class'] = self.n_classes_ params['num_class'] = self.n_classes_
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
xgb_options.update({"eval_metric": eval_metric})
if self.use_label_encoder: if self.use_label_encoder:
if not can_use_label_encoder: if not can_use_label_encoder:
@ -891,6 +895,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
else: else:
label_transform = (lambda x: x) label_transform = (lambda x: x)
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
if len(X.shape) != 2: if len(X.shape) != 2:
# Simply raise an error here since there might be many # Simply raise an error here since there might be many
# different ways of reshaping # different ways of reshaping
@ -906,15 +911,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
eval_set=eval_set, sample_weight_eval_set=sample_weight_eval_set, eval_set=eval_set, sample_weight_eval_set=sample_weight_eval_set,
eval_group=None, label_transform=label_transform) eval_group=None, label_transform=label_transform)
self._Booster = train(xgb_options, train_dmatrix, self._Booster = train(params, train_dmatrix,
self.get_num_boosting_rounds(), self.get_num_boosting_rounds(),
evals=evals, evals=evals,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval, evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model, verbose_eval=verbose, xgb_model=model,
callbacks=callbacks) callbacks=callbacks)
self.objective = xgb_options["objective"] self.objective = params["objective"]
if evals_result: if evals_result:
for val in evals_result.items(): for val in evals_result.items():
evals_result_key = list(val[1].keys())[0] evals_result_key = list(val[1].keys())[0]

View File

@ -4,11 +4,10 @@
"""Training Library containing training routines.""" """Training Library containing training routines."""
import warnings import warnings
import copy import copy
import json
import numpy as np import numpy as np
from .core import Booster, XGBoostError from .core import Booster, XGBoostError
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold) from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import rabit
from . import callback from . import callback
@ -51,28 +50,12 @@ def _train_internal(params, dtrain,
evals = list(evals) evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals]) bst = Booster(params, [dtrain] + [d[0] for d in evals])
nboost = 0
num_parallel_tree = 1
if xgb_model is not None: if xgb_model is not None:
bst = Booster(params, [dtrain] + [d[0] for d in evals], bst = Booster(params, [dtrain] + [d[0] for d in evals],
model_file=xgb_model) model_file=xgb_model)
nboost = len(bst.get_dump())
_params = dict(params) if isinstance(params, list) else params start_iteration = 0
if 'num_parallel_tree' in _params and _params[
'num_parallel_tree'] is not None:
num_parallel_tree = _params['num_parallel_tree']
nboost //= num_parallel_tree
if 'num_class' in _params and _params['num_class'] is not None:
nboost //= _params['num_class']
# Distributed code: Load the checkpoint from rabit.
version = bst.load_rabit_checkpoint()
assert rabit.get_world_size() != 1 or version == 0
start_iteration = int(version / 2)
nboost += start_iteration
is_new_callback = _is_new_callback(callbacks) is_new_callback = _is_new_callback(callbacks)
if is_new_callback: if is_new_callback:
@ -92,26 +75,13 @@ def _train_internal(params, dtrain,
show_stdv=False, cvfolds=None) show_stdv=False, cvfolds=None)
bst = callbacks.before_training(bst) bst = callbacks.before_training(bst)
for i in range(start_iteration, num_boost_round): for i in range(start_iteration, num_boost_round):
if callbacks.before_iteration(bst, i, dtrain, evals): if callbacks.before_iteration(bst, i, dtrain, evals):
break break
# Distributed code: need to resume to this point.
# Skip the first update if it is a recovery step.
if version % 2 == 0:
bst.update(dtrain, i, obj) bst.update(dtrain, i, obj)
bst.save_rabit_checkpoint()
version += 1
assert rabit.get_world_size() == 1 or version == rabit.version_number()
nboost += 1
# check evaluation result.
if callbacks.after_iteration(bst, i, dtrain, evals): if callbacks.after_iteration(bst, i, dtrain, evals):
break break
# do checkpoint after evaluation, in case evaluation also updates
# booster.
bst.save_rabit_checkpoint()
version += 1
bst = callbacks.after_training(bst) bst = callbacks.after_training(bst)
@ -122,7 +92,12 @@ def _train_internal(params, dtrain,
bst.best_score = float(bst.attr('best_score')) bst.best_score = float(bst.attr('best_score'))
bst.best_iteration = int(bst.attr('best_iteration')) bst.best_iteration = int(bst.attr('best_iteration'))
else: else:
bst.best_iteration = nboost - 1 bst.best_iteration = bst.num_boosted_rounds() - 1
try:
num_parallel_tree = int(json.loads(bst.save_config())['learner'][
'gradient_booster']['gbtree_train_param']['num_parallel_tree'])
except KeyError: # gblinear
num_parallel_tree = 1
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
# Copy to serialise and unserialise booster to reset state and free # Copy to serialise and unserialise booster to reset state and free
# training memory # training memory
@ -234,7 +209,7 @@ class CVPack(object):
class _PackedBooster: class _PackedBooster:
def __init__(self, cvfolds): def __init__(self, cvfolds) -> None:
self.cvfolds = cvfolds self.cvfolds = cvfolds
def update(self, iteration, obj): def update(self, iteration, obj):
@ -262,6 +237,10 @@ class _PackedBooster:
ret = self.cvfolds[0].bst.attr('best_iteration') ret = self.cvfolds[0].bst.attr('best_iteration')
return int(ret) return int(ret)
def num_boosted_rounds(self) -> int:
    """Return the boosted-round count reported by the first fold's booster."""
    first_fold = self.cvfolds[0]
    return first_fold.bst.num_boosted_rounds()
def groups_to_rows(groups, boundaries): def groups_to_rows(groups, boundaries):
""" """

View File

@ -502,6 +502,14 @@ XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
API_END(); API_END();
} }
XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out) {
  API_BEGIN();
  CHECK_HANDLE();
  auto* learner = static_cast<Learner*>(handle);
  // Configure before querying the number of boosted rounds.
  learner->Configure();
  *out = learner->BoostedRounds();
  API_END();
}
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) { XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();

View File

@ -73,6 +73,10 @@ class GBLinear : public GradientBooster {
} }
} }
// Number of boosted rounds tracked by the linear model.
int32_t BoostedRounds() const override { return this->model_.num_boosted_rounds; }
void Load(dmlc::Stream* fi) override { void Load(dmlc::Stream* fi) override {
model_.Load(fi); model_.Load(fi);
} }
@ -122,7 +126,7 @@ class GBLinear : public GradientBooster {
if (!this->CheckConvergence()) { if (!this->CheckConvergence()) {
updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_); updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_);
} }
model_.num_boosted_rounds++;
monitor_.Stop("DoBoost"); monitor_.Stop("DoBoost");
} }

View File

@ -44,11 +44,12 @@ class GBLinearModel : public Model {
DeprecatedGBLinearModelParam param_; DeprecatedGBLinearModelParam param_;
public: public:
int32_t num_boosted_rounds;
LearnerModelParam const* learner_model_param; LearnerModelParam const* learner_model_param;
public: public:
explicit GBLinearModel(LearnerModelParam const* learner_model_param) : explicit GBLinearModel(LearnerModelParam const* learner_model_param) :
learner_model_param {learner_model_param} {} num_boosted_rounds{0}, learner_model_param {learner_model_param} {}
void Configure(Args const &) { } void Configure(Args const &) { }
// weight for each of feature, bias is the last one // weight for each of feature, bias is the last one

View File

@ -249,10 +249,17 @@ class GBTree : public GradientBooster {
auto n_trees = model_.learner_model_param->num_output_group * tparam_.num_parallel_tree; auto n_trees = model_.learner_model_param->num_output_group * tparam_.num_parallel_tree;
return n_trees; return n_trees;
} }
// slice the trees, out must be already allocated // slice the trees, out must be already allocated
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step, void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
GradientBooster *out, bool* out_of_bound) const override; GradientBooster *out, bool* out_of_bound) const override;
// Number of boosted rounds: total tree count divided by the number of trees
// per layer.  Both factors of the layer size are checked to be non-zero
// before dividing.
int32_t BoostedRounds() const override {
  CHECK_NE(tparam_.num_parallel_tree, 0);
  CHECK_NE(model_.learner_model_param->num_output_group, 0);
  auto const layer_trees = this->LayerTrees();
  return model_.trees.size() / layer_trees;
}
void PredictBatch(DMatrix* p_fmat, void PredictBatch(DMatrix* p_fmat,
PredictionCacheEntry* out_preds, PredictionCacheEntry* out_preds,
bool training, bool training,

View File

@ -1107,6 +1107,12 @@ class LearnerImpl : public LearnerIO {
} }
} }
// Forward to the gradient booster; the learner must already be configured.
int32_t BoostedRounds() const override {
  if (this->gbm_ == nullptr) {
    // Neither training nor LoadModel has been called yet.
    return 0;
  }
  CHECK(!this->need_configuration_);
  return this->gbm_->BoostedRounds();
}
XGBAPIThreadLocalEntry& GetThreadLocal() const override { XGBAPIThreadLocalEntry& GetThreadLocal() const override {
return (*LearnerAPIThreadLocalStore::Get())[this]; return (*LearnerAPIThreadLocalStore::Get())[this];
} }

View File

@ -124,6 +124,20 @@ class TestModels:
predt_2 = bst.predict(dtrain) predt_2 = bst.predict(dtrain)
assert np.all(np.abs(predt_2 - predt_1) < 1e-6) assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
def test_boost_from_existing_model(self):
    """Check num_boosted_rounds() for fresh, continued and update training."""
    dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')

    fresh = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=4)
    assert fresh.num_boosted_rounds() == 4

    # Continuing from an existing model adds to its rounds.
    continued = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=4,
                          xgb_model=fresh)
    assert continued.num_boosted_rounds() == 8

    # Trees are moved for the update, so the number of rounds is reduced.
    # This test is written to be compatible with the current code (1.0.0).
    # If the behaviour is considered sub-optimal, feel free to change it.
    updated = xgb.train({'updater': 'prune', 'process_type': 'update'}, dtrain,
                        num_boost_round=4, xgb_model=continued)
    assert updated.num_boosted_rounds() == 4
def test_custom_objective(self): def test_custom_objective(self):
param = {'max_depth': 2, 'eta': 1, 'objective': 'reg:logistic'} param = {'max_depth': 2, 'eta': 1, 'objective': 'reg:logistic'}
watchlist = [(dtest, 'eval'), (dtrain, 'train')] watchlist = [(dtest, 'eval'), (dtrain, 'train')]

View File

@ -81,6 +81,15 @@ class TestCallbacks:
dump = booster.get_dump(dump_format='json') dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
# No early stopping, best_iteration should be set to last epoch
booster = xgb.train({'objective': 'binary:logistic',
'eval_metric': 'error'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
num_boost_round=10,
evals_result=evals_result,
verbose_eval=True)
assert booster.num_boosted_rounds() - 1 == booster.best_iteration
def test_early_stopping_custom_eval(self): def test_early_stopping_custom_eval(self):
D_train = xgb.DMatrix(self.X_train, self.y_train) D_train = xgb.DMatrix(self.X_train, self.y_train)
D_valid = xgb.DMatrix(self.X_valid, self.y_valid) D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
@ -153,7 +162,7 @@ class TestCallbacks:
eval_metric=tm.eval_error_metric, callbacks=[early_stop]) eval_metric=tm.eval_error_metric, callbacks=[early_stop])
booster = cls.get_booster() booster = cls.get_booster()
dump = booster.get_dump(dump_format='json') dump = booster.get_dump(dump_format='json')
assert len(dump) == booster.best_iteration assert len(dump) == booster.best_iteration + 1
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True) save_best=True)
@ -170,6 +179,32 @@ class TestCallbacks:
eval_metric=tm.eval_error_metric, eval_metric=tm.eval_error_metric,
callbacks=[early_stop]) callbacks=[early_stop])
def test_early_stopping_continuation(self):
# Checks the relation between num_boosted_rounds() and best_iteration, first
# for a fit that uses an EarlyStopping callback with save_best=True, then for
# a second fit on an estimator that was reloaded from a saved model file.
# NOTE(review): the original indentation is not visible in this view --
# confirm which statements belong inside the `with` block below.
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
cls = xgb.XGBClassifier()
early_stopping_rounds = 5
# First fit: early stopping through the callback, keeping the best model.
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)
cls.fit(X, y, eval_set=[(X, y)],
eval_metric=tm.eval_error_metric,
callbacks=[early_stop])
booster = cls.get_booster()
# With save_best the booster is expected to end at the best iteration.
assert booster.num_boosted_rounds() == booster.best_iteration + 1
# Round-trip the model through a JSON file into a fresh estimator.
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'model.json')
cls.save_model(path)
cls = xgb.XGBClassifier()
cls.load_model(path)
assert cls._Booster is not None
early_stopping_rounds = 3
# Second fit: no xgb_model argument and no save_best callback are passed.
cls.fit(X, y, eval_set=[(X, y)], eval_metric=tm.eval_error_metric,
early_stopping_rounds=early_stopping_rounds)
booster = cls.get_booster()
# Without save_best the rounds trained after the best iteration are expected
# to remain in the booster.
assert booster.num_boosted_rounds() == \
booster.best_iteration + early_stopping_rounds + 1
def run_eta_decay(self, tree_method, deprecated_callback): def run_eta_decay(self, tree_method, deprecated_callback):
if deprecated_callback: if deprecated_callback:
scheduler = xgb.callback.reset_learning_rate scheduler = xgb.callback.reset_learning_rate

View File

@ -62,6 +62,8 @@ def test_multiclass_classification():
kf = KFold(n_splits=2, shuffle=True, random_state=rng) kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X, y): for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index]) xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
assert (xgb_model.get_booster().num_boosted_rounds() ==
xgb_model.n_estimators)
preds = xgb_model.predict(X[test_index]) preds = xgb_model.predict(X[test_index])
# test other params in XGBClassifier().fit # test other params in XGBClassifier().fit
preds2 = xgb_model.predict(X[test_index], output_margin=True, preds2 = xgb_model.predict(X[test_index], output_margin=True,