Support early stopping with training continuation, correct num boosted rounds. (#6506)

* Implement early stopping with training continuation.

* Add new C API for obtaining boosted rounds.

* Fix off-by-one error in `save_best`.

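A minimal usage sketch of the combined behavior (hedged: dataset paths and parameters are illustrative; the API calls are the ones touched by this commit):

import xgboost as xgb

dtrain = xgb.DMatrix('agaricus.txt.train')  # illustrative path
dvalid = xgb.DMatrix('agaricus.txt.test')   # illustrative path

# First session: 4 boosted rounds.
booster = xgb.train({'objective': 'binary:logistic'}, dtrain, num_boost_round=4)
assert booster.num_boosted_rounds() == 4  # backed by the new C API

# Continuation: EarlyStopping reads the starting round from the model,
# so epochs and best_iteration stay global round indices.
early_stop = xgb.callback.EarlyStopping(rounds=5, save_best=True)
booster = xgb.train({'objective': 'binary:logistic'}, dtrain,
                    num_boost_round=16, evals=[(dvalid, 'Valid')],
                    xgb_model=booster, callbacks=[early_stop])
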
Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
Jiaming Yuan 2020-12-17 19:59:19 +08:00 committed by GitHub
parent 125b3c0f2d
commit ca3da55de4
16 changed files with 210 additions and 118 deletions

View File

@@ -614,6 +614,15 @@ XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
int end_layer, int step,
BoosterHandle *out);
/*!
* \brief Get number of boosted rounds from gradient booster. When process_type is
* update, this number might drop due to removed trees.
* \param handle Handle to booster.
* \param out Pointer to output integer.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out);
/*!
* \brief set parameters
* \param handle handle
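The Python wrapper added in core.py below drives this entry point through ctypes; a sketch of the equivalent raw call (assuming `booster` is an existing xgboost.Booster, and that touching the private helpers `_LIB`/`_check_call` is acceptable):

import ctypes
from xgboost.core import _LIB, _check_call

rounds = ctypes.c_int()
_check_call(_LIB.XGBoosterBoostedRounds(booster.handle, ctypes.byref(rounds)))
print(rounds.value)  # number of boosted rounds, e.g. 4
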

View File

@@ -79,6 +79,9 @@ class GradientBooster : public Model, public Configurable {
virtual bool AllowLazyCheckPoint() const {
return false;
}
/*! \brief Return number of boosted rounds.
*/
virtual int32_t BoostedRounds() const = 0;
/*!
* \brief perform update to the model (boosting)
* \param p_fmat feature matrix that provide access to features

View File

@@ -134,6 +134,11 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
HostDeviceVector<bst_float> **out_preds,
uint32_t layer_begin, uint32_t layer_end) = 0;
/*!
* \brief Get number of boosted rounds from gradient booster.
*/
virtual int32_t BoostedRounds() const = 0;
void LoadModel(Json const& in) override = 0;
void SaveModel(Json* out) const override = 0;

View File

@@ -6,7 +6,7 @@ from abc import ABC
import collections
import os
import pickle
from typing import Callable, List
from typing import Callable, List, Optional, Union, Dict, Tuple
import numpy
from . import rabit
@@ -285,11 +285,13 @@ class TrainingCallback(ABC):
'''Run after training is finished.'''
return model
def before_iteration(self, model, epoch, evals_log):
def before_iteration(self, model, epoch: int,
evals_log: 'CallbackContainer.EvalsLog') -> bool:
'''Run before each iteration. Return True when training should stop.'''
return False
def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: 'CallbackContainer.EvalsLog') -> bool:
'''Run after each iteration. Return True when training should stop.'''
return False
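As an illustration of this contract, a hypothetical callback (not part of the commit) that stops training after a fixed number of iterations:

import xgboost as xgb

class StopAfter(xgb.callback.TrainingCallback):
    '''Hypothetical: request a stop once n iterations have run.'''
    def __init__(self, n: int) -> None:
        self.n = n
        super().__init__()

    def after_iteration(self, model, epoch, evals_log) -> bool:
        # Returning True asks the training loop to stop.
        return epoch + 1 >= self.n
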
@@ -346,8 +348,13 @@ class CallbackContainer:
.. versionadded:: 1.3.0
'''
def __init__(self, callbacks: List[TrainingCallback],
metric: Callable = None, is_cv: bool = False):
EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
def __init__(self,
callbacks: List[TrainingCallback],
metric: Callable = None,
is_cv: bool = False):
self.callbacks = set(callbacks)
if metric is not None:
msg = 'metric must be callable object for monitoring. For ' + \
@@ -355,7 +362,7 @@
' will invoke monitor automatically.'
assert callable(metric), msg
self.metric = metric
self.history = collections.OrderedDict()
self.history: CallbackContainer.EvalsLog = collections.OrderedDict()
self.is_cv = is_cv
if self.is_cv:
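For reference, the history stored under the EvalsLog alias looks like this after two rounds (made-up values; cv entries are (mean, std) tuples):

# plain training
{'Valid': {'error': [0.125, 0.118]}}
# cross validation
{'Valid': {'error': [(0.125, 0.003), (0.118, 0.002)]}}
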
@@ -383,7 +390,7 @@
assert isinstance(model, Booster), msg
return model
def before_iteration(self, model, epoch, dtrain, evals):
def before_iteration(self, model, epoch, dtrain, evals) -> bool:
'''Function called before training iteration.'''
return any(c.before_iteration(model, epoch, self.history)
for c in self.callbacks)
@@ -409,7 +416,7 @@
self.history[data_name][metric_name] = [s]
return False
def after_iteration(self, model, epoch, dtrain, evals):
def after_iteration(self, model, epoch, dtrain, evals) -> bool:
'''Function called after training iteration.'''
if self.is_cv:
scores = model.eval(epoch, self.metric)
@@ -445,7 +452,7 @@ class LearningRateScheduler(TrainingCallback):
rounds.
'''
def __init__(self, learning_rates):
def __init__(self, learning_rates) -> None:
assert callable(learning_rates) or \
isinstance(learning_rates, collections.abc.Sequence)
if callable(learning_rates):
@@ -454,41 +461,42 @@ class LearningRateScheduler(TrainingCallback):
self.learning_rates = lambda epoch: learning_rates[epoch]
super().__init__()
def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch, evals_log) -> bool:
model.set_param('learning_rate', self.learning_rates(epoch))
return False
# pylint: disable=too-many-instance-attributes
class EarlyStopping(TrainingCallback):
''' Callback function for early stopping
"""Callback function for early stopping
.. versionadded:: 1.3.0
Parameters
----------
rounds : int
rounds
Early stopping rounds.
metric_name : str
metric_name
Name of metric that is used for early stopping.
data_name: str
data_name
Name of dataset that is used for early stopping.
maximize : bool
maximize
Whether to maximize evaluation metric. None means auto (discouraged).
save_best : bool
save_best
Whether training should return the best model or the last model.
'''
"""
def __init__(self,
rounds,
metric_name=None,
data_name=None,
maximize=None,
save_best=False):
rounds: int,
metric_name: Optional[str] = None,
data_name: Optional[str] = None,
maximize: Optional[bool] = None,
save_best: Optional[bool] = False) -> None:
self.data = data_name
self.metric_name = metric_name
self.rounds = rounds
self.save_best = save_best
self.maximize = maximize
self.stopping_history = {}
self.stopping_history: CallbackContainer.EvalsLog = {}
if self.maximize is not None:
if self.maximize:
@@ -496,11 +504,16 @@ class EarlyStopping(TrainingCallback):
else:
self.improve_op = lambda x, y: x < y
self.current_rounds = 0
self.best_scores = {}
self.current_rounds: int = 0
self.best_scores: dict = {}
self.starting_round: int = 0
super().__init__()
def _update_rounds(self, score, name, metric, model, epoch):
def before_training(self, model):
self.starting_round = model.num_boosted_rounds()
return model
def _update_rounds(self, score, name, metric, model, epoch) -> bool:
# Just for compatibility with the old behavior before 1.3. We should let
# the user decide.
if self.maximize is None:
@@ -536,7 +549,9 @@ class EarlyStopping(TrainingCallback):
return True
return False
def after_iteration(self, model: Booster, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
epoch += self.starting_round # training continuation
msg = 'Must have at least 1 validation dataset for early stopping.'
assert len(evals_log.keys()) >= 1, msg
data_name = ''
@@ -562,12 +577,14 @@ class EarlyStopping(TrainingCallback):
score = data_log[metric_name][-1]
return self._update_rounds(score, data_name, metric_name, model, epoch)
def after_training(self, model: Booster):
def after_training(self, model):
try:
if self.save_best:
model = model[: int(model.attr('best_iteration'))]
model = model[: int(model.attr("best_iteration")) + 1]
except XGBoostError as e:
raise XGBoostError('`save_best` is not applicable to current booster') from e
raise XGBoostError(
"`save_best` is not applicable to current booster"
) from e
return model
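The `+ 1` matters because Booster slicing is end-exclusive over boosting rounds; a sketch with illustrative numbers:

best = int(booster.attr('best_iteration'))  # e.g. 9, the 10th round
sliced = booster[: best + 1]                # keeps rounds 0..9
assert sliced.num_boosted_rounds() == best + 1
# The previous code sliced booster[:best] and silently dropped the best round.
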
@@ -588,36 +605,37 @@ class EvaluationMonitor(TrainingCallback):
show_stdv : bool
Used in cv to show standard deviation. Users should not specify it.
'''
def __init__(self, rank=0, period=1, show_stdv=False):
def __init__(self, rank=0, period=1, show_stdv=False) -> None:
self.printer_rank = rank
self.show_stdv = show_stdv
self.period = period
assert period > 0
# last evaluation message, useful when early stopping and period are used together.
self._latest = None
self._latest: Optional[str] = None
super().__init__()
def _fmt_metric(self, data, metric, score, std):
def _fmt_metric(self, data, metric, score, std) -> str:
if std is not None and self.show_stdv:
msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
else:
msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
return msg
def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
if not evals_log:
return False
msg = f'[{epoch}]'
msg: str = f'[{epoch}]'
if rabit.get_rank() == self.printer_rank:
for data, metric in evals_log.items():
for metric_name, log in metric.items():
stdv: Optional[float] = None
if isinstance(log[-1], tuple):
score = log[-1][0]
stdv = log[-1][1]
else:
score = log[-1]
stdv = None
msg += self._fmt_metric(data, metric_name, score, stdv)
msg += '\n'
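With this formatting, one printed line per iteration looks like the following (illustrative scores; the '+std' suffix appears only in cv with show_stdv enabled):

[3]	Train-error:0.01470	Valid-error:0.02258
[3]	Valid-error:0.02258+0.00120
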
@@ -665,7 +683,8 @@ class TrainingCheckPoint(TrainingCallback):
self._epoch = 0
super().__init__()
def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
if self._epoch == self._iterations:
path = os.path.join(self._path, self._name + '_' + str(epoch) +
('.pkl' if self._as_pickle else '.json'))
@@ -677,6 +696,7 @@ class TrainingCheckPoint(TrainingCallback):
else:
model.save_model(path)
self._epoch += 1
return False
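Given the path construction above, a checkpoint written at epoch 10 lands at, for example ('/tmp/ckpt' and 'model' are illustrative stand-ins for self._path and self._name):

# as_pickle=False -> /tmp/ckpt/model_10.json
# as_pickle=True  -> /tmp/ckpt/model_10.pkl
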
class LegacyCallbacks:

View File

@@ -1177,23 +1177,6 @@ class Booster(object):
"""
return self.__copy__()
def load_rabit_checkpoint(self):
"""Initialize the model by load from rabit checkpoint.
Returns
-------
version: integer
The version number of the model.
"""
version = ctypes.c_int()
_check_call(_LIB.XGBoosterLoadRabitCheckpoint(
self.handle, ctypes.byref(version)))
return version.value
def save_rabit_checkpoint(self):
"""Save the current booster to rabit checkpoint."""
_check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle))
def attr(self, key):
"""Get attribute string from the Booster.
@@ -1745,6 +1728,17 @@ class Booster(object):
else:
raise TypeError('Unknown file type: ', fname)
def num_boosted_rounds(self) -> int:
'''Get number of boosted rounds. For gblinear this is reset to 0 after
serializing the model.
'''
rounds = ctypes.c_int()
assert self.handle is not None
_check_call(_LIB.XGBoosterBoostedRounds(
self.handle, ctypes.byref(rounds)))
return rounds.value
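A hedged illustration of the gblinear caveat (per the docstring, the round counter lives only in the in-memory model; dtrain is assumed to exist):

import pickle
bst = xgb.train({'booster': 'gblinear'}, dtrain, num_boost_round=4)
assert bst.num_boosted_rounds() == 4
bst2 = pickle.loads(pickle.dumps(bst))  # serialise + deserialise
assert bst2.num_boosted_rounds() == 0   # reset, per the docstring
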
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
"""Dump model into a text or JSON file. Unlike `save_model`, the
output format is primarily used for visualization or interpretation,

View File

@@ -4,6 +4,7 @@
import copy
import warnings
import json
from typing import Union, Optional, List, Dict, Callable, Tuple, Any
import numpy as np
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
from .training import train
@@ -494,6 +495,22 @@ class XGBModel(XGBModelBase):
# Delete the attribute after load
self.get_booster().set_attr(scikit_learn=None)
def _configure_fit(
self,
booster: Optional[Booster],
eval_metric: Optional[Union[Callable, str, List[str]]],
params: Dict[str, Any],
) -> Tuple[Booster, Optional[Union[Callable, str, List[str]]], Dict[str, Any]]:
model = self._Booster if hasattr(self, "_Booster") else None
model = booster if booster is not None else model
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({"eval_metric": eval_metric})
return model, feval, params
@_deprecate_positional_args
def fit(self, X, y, *, sample_weight=None, base_margin=None,
eval_set=None, eval_metric=None, early_stopping_rounds=None,
@@ -586,19 +603,13 @@
else:
obj = None
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({'eval_metric': eval_metric})
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
self._Booster = train(params, train_dmatrix,
self.get_num_boosting_rounds(), evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result,
obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
verbose_eval=verbose, xgb_model=model,
callbacks=callbacks)
if evals_result:
@@ -857,27 +868,20 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
not np.array_equal(self.classes_, np.arange(self.n_classes_))):
raise ValueError(label_encoding_check_error)
xgb_options = self.get_xgb_params()
params = self.get_xgb_params()
if callable(self.objective):
obj = _objective_decorator(self.objective)
# Use default value. Is it really not used?
xgb_options["objective"] = "binary:logistic"
params["objective"] = "binary:logistic"
else:
obj = None
if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying
# XGB instance
xgb_options['objective'] = 'multi:softprob'
xgb_options['num_class'] = self.n_classes_
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
xgb_options.update({"eval_metric": eval_metric})
params['objective'] = 'multi:softprob'
params['num_class'] = self.n_classes_
if self.use_label_encoder:
if not can_use_label_encoder:
@@ -891,6 +895,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
else:
label_transform = (lambda x: x)
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
if len(X.shape) != 2:
# Simply raise an error here since there might be many
# different ways of reshaping
@@ -906,15 +911,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
eval_set=eval_set, sample_weight_eval_set=sample_weight_eval_set,
eval_group=None, label_transform=label_transform)
self._Booster = train(xgb_options, train_dmatrix,
self._Booster = train(params, train_dmatrix,
self.get_num_boosting_rounds(),
evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
verbose_eval=verbose, xgb_model=model,
callbacks=callbacks)
self.objective = xgb_options["objective"]
self.objective = params["objective"]
if evals_result:
for val in evals_result.items():
evals_result_key = list(val[1].keys())[0]

View File

@@ -4,11 +4,10 @@
"""Training Library containing training routines."""
import warnings
import copy
import json
import numpy as np
from .core import Booster, XGBoostError
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import rabit
from . import callback
@@ -51,28 +50,12 @@ def _train_internal(params, dtrain,
evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals])
nboost = 0
num_parallel_tree = 1
if xgb_model is not None:
bst = Booster(params, [dtrain] + [d[0] for d in evals],
model_file=xgb_model)
nboost = len(bst.get_dump())
_params = dict(params) if isinstance(params, list) else params
if 'num_parallel_tree' in _params and _params[
'num_parallel_tree'] is not None:
num_parallel_tree = _params['num_parallel_tree']
nboost //= num_parallel_tree
if 'num_class' in _params and _params['num_class'] is not None:
nboost //= _params['num_class']
# Distributed code: Load the checkpoint from rabit.
version = bst.load_rabit_checkpoint()
assert rabit.get_world_size() != 1 or version == 0
start_iteration = int(version / 2)
nboost += start_iteration
start_iteration = 0
is_new_callback = _is_new_callback(callbacks)
if is_new_callback:
@@ -92,26 +75,13 @@ def _train_internal(params, dtrain,
show_stdv=False, cvfolds=None)
bst = callbacks.before_training(bst)
for i in range(start_iteration, num_boost_round):
if callbacks.before_iteration(bst, i, dtrain, evals):
break
# Distributed code: need to resume to this point.
# Skip the first update if it is a recovery step.
if version % 2 == 0:
bst.update(dtrain, i, obj)
bst.save_rabit_checkpoint()
version += 1
assert rabit.get_world_size() == 1 or version == rabit.version_number()
nboost += 1
# check evaluation result.
bst.update(dtrain, i, obj)
if callbacks.after_iteration(bst, i, dtrain, evals):
break
# do checkpoint after evaluation, in case evaluation also updates
# booster.
bst.save_rabit_checkpoint()
version += 1
bst = callbacks.after_training(bst)
@@ -122,7 +92,12 @@ def _train_internal(params, dtrain,
bst.best_score = float(bst.attr('best_score'))
bst.best_iteration = int(bst.attr('best_iteration'))
else:
bst.best_iteration = nboost - 1
bst.best_iteration = bst.num_boosted_rounds() - 1
try:
num_parallel_tree = int(json.loads(bst.save_config())['learner'][
'gradient_booster']['gbtree_train_param']['num_parallel_tree'])
except KeyError: # gblinear
num_parallel_tree = 1
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
# Copy to serialise and deserialise the booster to reset state and free
# training memory
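The same lookup can be reproduced by hand against the saved config (key path exactly as in the code above); note that best_ntree_limit counts individual trees, not rounds:

import json
conf = json.loads(bst.save_config())
n_parallel = int(conf['learner']['gradient_booster']
                 ['gbtree_train_param']['num_parallel_tree'])
# e.g. best_iteration=9, num_parallel_tree=4 -> best_ntree_limit = 40
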
@@ -234,7 +209,7 @@ class CVPack(object):
class _PackedBooster:
def __init__(self, cvfolds):
def __init__(self, cvfolds) -> None:
self.cvfolds = cvfolds
def update(self, iteration, obj):
@@ -262,6 +237,10 @@ class _PackedBooster:
ret = self.cvfolds[0].bst.attr('best_iteration')
return int(ret)
def num_boosted_rounds(self) -> int:
'''Number of boosted rounds.'''
return self.cvfolds[0].bst.num_boosted_rounds()
def groups_to_rows(groups, boundaries):
"""

View File

@@ -502,6 +502,14 @@ XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
API_END();
}
XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out) {
API_BEGIN();
CHECK_HANDLE();
static_cast<Learner*>(handle)->Configure();
*out = static_cast<Learner*>(handle)->BoostedRounds();
API_END();
}
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
API_BEGIN();
CHECK_HANDLE();

View File

@@ -73,6 +73,10 @@ class GBLinear : public GradientBooster {
}
}
int32_t BoostedRounds() const override {
return model_.num_boosted_rounds;
}
void Load(dmlc::Stream* fi) override {
model_.Load(fi);
}
@@ -122,7 +126,7 @@ class GBLinear : public GradientBooster {
if (!this->CheckConvergence()) {
updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_);
}
model_.num_boosted_rounds++;
monitor_.Stop("DoBoost");
}

View File

@@ -44,11 +44,12 @@ class GBLinearModel : public Model {
DeprecatedGBLinearModelParam param_;
public:
int32_t num_boosted_rounds;
LearnerModelParam const* learner_model_param;
public:
explicit GBLinearModel(LearnerModelParam const* learner_model_param) :
learner_model_param {learner_model_param} {}
num_boosted_rounds{0}, learner_model_param {learner_model_param} {}
void Configure(Args const &) { }
// weight for each of feature, bias is the last one

View File

@@ -249,10 +249,17 @@ class GBTree : public GradientBooster {
auto n_trees = model_.learner_model_param->num_output_group * tparam_.num_parallel_tree;
return n_trees;
}
// Slice the trees; out must already be allocated.
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
GradientBooster *out, bool* out_of_bound) const override;
int32_t BoostedRounds() const override {
CHECK_NE(tparam_.num_parallel_tree, 0);
CHECK_NE(model_.learner_model_param->num_output_group, 0);
return model_.trees.size() / this->LayerTrees();
}
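Since a layer holds num_output_group * num_parallel_tree trees (see n_trees above), the round count is a plain division; a worked example with illustrative numbers:

num_output_group, num_parallel_tree = 3, 4          # e.g. multi:softprob forest
layer_trees = num_output_group * num_parallel_tree  # 12 trees per round
assert 24 // layer_trees == 2                       # 24 trees -> 2 boosted rounds
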
void PredictBatch(DMatrix* p_fmat,
PredictionCacheEntry* out_preds,
bool training,

View File

@@ -1107,6 +1107,12 @@ class LearnerImpl : public LearnerIO {
}
}
int32_t BoostedRounds() const override {
if (!this->gbm_) { return 0; }  // haven't called train or LoadModel yet.
CHECK(!this->need_configuration_);
return this->gbm_->BoostedRounds();
}
XGBAPIThreadLocalEntry& GetThreadLocal() const override {
return (*LearnerAPIThreadLocalStore::Get())[this];
}

View File

@@ -124,6 +124,20 @@ class TestModels:
predt_2 = bst.predict(dtrain)
assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
def test_boost_from_existing_model(self):
X = xgb.DMatrix(dpath + 'agaricus.txt.train')
booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4)
assert booster.num_boosted_rounds() == 4
booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4,
xgb_model=booster)
assert booster.num_boosted_rounds() == 8
booster = xgb.train({'updater': 'prune', 'process_type': 'update'}, X,
num_boost_round=4, xgb_model=booster)
# Trees are moved into the new model during update, so the number of
# rounds is reduced. This test is written to be compatible with the
# current code (1.0.0). If the behaviour is considered sub-optimal,
# feel free to change.
assert booster.num_boosted_rounds() == 4
def test_custom_objective(self):
param = {'max_depth': 2, 'eta': 1, 'objective': 'reg:logistic'}
watchlist = [(dtest, 'eval'), (dtrain, 'train')]

View File

@@ -43,7 +43,7 @@ class TestCallbacks:
# Should print info for each period, in addition to the first and latest iterations
num_periods = rounds // int(verbose_eval)
# Extra information is required for latest iteration
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required)
def test_evaluation_monitor(self):
@@ -63,7 +63,7 @@ class TestCallbacks:
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
self.run_evaluation_monitor(D_train, D_valid, rounds, 4)
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)
def test_early_stopping(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
@@ -81,6 +81,15 @@ class TestCallbacks:
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
# No early stopping, best_iteration should be set to last epoch
booster = xgb.train({'objective': 'binary:logistic',
'eval_metric': 'error'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
num_boost_round=10,
evals_result=evals_result,
verbose_eval=True)
assert booster.num_boosted_rounds() - 1 == booster.best_iteration
def test_early_stopping_custom_eval(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
@@ -153,7 +162,7 @@ class TestCallbacks:
eval_metric=tm.eval_error_metric, callbacks=[early_stop])
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
assert len(dump) == booster.best_iteration
assert len(dump) == booster.best_iteration + 1
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)
@@ -170,6 +179,32 @@ class TestCallbacks:
eval_metric=tm.eval_error_metric,
callbacks=[early_stop])
def test_early_stopping_continuation(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
cls = xgb.XGBClassifier()
early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)
cls.fit(X, y, eval_set=[(X, y)],
eval_metric=tm.eval_error_metric,
callbacks=[early_stop])
booster = cls.get_booster()
assert booster.num_boosted_rounds() == booster.best_iteration + 1
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'model.json')
cls.save_model(path)
cls = xgb.XGBClassifier()
cls.load_model(path)
assert cls._Booster is not None
early_stopping_rounds = 3
cls.fit(X, y, eval_set=[(X, y)], eval_metric=tm.eval_error_metric,
early_stopping_rounds=early_stopping_rounds)
booster = cls.get_booster()
assert booster.num_boosted_rounds() == \
booster.best_iteration + early_stopping_rounds + 1
def run_eta_decay(self, tree_method, deprecated_callback):
if deprecated_callback:
scheduler = xgb.callback.reset_learning_rate

View File

@@ -46,7 +46,7 @@ class TestEarlyStopping:
@staticmethod
def assert_metrics_length(cv, expected_length):
for key, value in cv.items():
assert len(value) == expected_length
@pytest.mark.skipif(**tm.no_sklearn())
def test_cv_early_stopping(self):

View File

@@ -62,6 +62,8 @@ def test_multiclass_classification():
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
assert (xgb_model.get_booster().num_boosted_rounds() ==
xgb_model.n_estimators)
preds = xgb_model.predict(X[test_index])
# test other params in XGBClassifier().fit
preds2 = xgb_model.predict(X[test_index], output_margin=True,