Define best_iteration only if early stopping is used. (#9403)

* Define `best_iteration` only if early stopping is used.

This is the behavior specified by the documentation but not honored by the actual code.

- Don't set the attributes if there's no early stopping.
- Clean up the callback code and replace assertions with proper exceptions.
- Assign the attributes when the early stopping `save_best` option is used.
- Turn the attributes into Python properties (a usage sketch follows this list).
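
A rough usage sketch of the resulting behavior, based only on the description above; the dataset, parameters, and round counts are illustrative and not part of this change:

import xgboost as xgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)
dtrain = xgb.DMatrix(X, y)

# Without early stopping, the attribute is no longer defined.
bst = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=4)
try:
    bst.best_iteration
except AttributeError:
    pass  # raised, per this change

# With early stopping, both properties are populated from the booster attributes.
bst = xgb.train(
    {"objective": "binary:logistic", "eval_metric": "error"},
    dtrain,
    num_boost_round=32,
    evals=[(dtrain, "Train")],
    early_stopping_rounds=2,
)
print(bst.best_iteration, bst.best_score)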

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
Jiaming Yuan 2023-07-24 12:43:35 +08:00 committed by GitHub
parent 01e00efc53
commit 851cba931e
10 changed files with 249 additions and 179 deletions

View File

@@ -1,9 +1,9 @@
-'''
+"""
 Demo for using and defining callback functions
 ==============================================
 .. versionadded:: 1.3.0
-'''
+"""
 import argparse
 import os
 import tempfile
@@ -17,10 +17,11 @@ import xgboost as xgb
 class Plotting(xgb.callback.TrainingCallback):
-    '''Plot evaluation result during training. Only for demonstration purpose as it's quite
+    """Plot evaluation result during training. Only for demonstration purpose as it's quite
     slow to draw.
-    '''
+    """
     def __init__(self, rounds):
         self.fig = plt.figure()
         self.ax = self.fig.add_subplot(111)
@@ -31,16 +32,16 @@ class Plotting(xgb.callback.TrainingCallback):
         plt.ion()

     def _get_key(self, data, metric):
-        return f'{data}-{metric}'
+        return f"{data}-{metric}"

     def after_iteration(self, model, epoch, evals_log):
-        '''Update the plot.'''
+        """Update the plot."""
         if not self.lines:
             for data, metric in evals_log.items():
                 for metric_name, log in metric.items():
                     key = self._get_key(data, metric_name)
                     expanded = log + [0] * (self.rounds - len(log))
-                    self.lines[key], = self.ax.plot(self.x, expanded, label=key)
+                    (self.lines[key],) = self.ax.plot(self.x, expanded, label=key)
             self.ax.legend()
         else:
             # https://pythonspot.com/matplotlib-update-plot/
@@ -55,8 +56,8 @@ class Plotting(xgb.callback.TrainingCallback):

 def custom_callback():
-    '''Demo for defining a custom callback function that plots evaluation result during
-    training.'''
+    """Demo for defining a custom callback function that plots evaluation result during
+    training."""
     X, y = load_breast_cancer(return_X_y=True)
     X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)
@@ -69,15 +70,16 @@ def custom_callback():
     # Pass it to the `callbacks` parameter as a list.
     xgb.train(
         {
-            'objective': 'binary:logistic',
-            'eval_metric': ['error', 'rmse'],
-            'tree_method': 'hist',
+            "objective": "binary:logistic",
+            "eval_metric": ["error", "rmse"],
+            "tree_method": "hist",
             "device": "cuda",
         },
         D_train,
-        evals=[(D_train, 'Train'), (D_valid, 'Valid')],
+        evals=[(D_train, "Train"), (D_valid, "Valid")],
         num_boost_round=num_boost_round,
-        callbacks=[plotting])
+        callbacks=[plotting],
+    )

 def check_point_callback():
@@ -90,10 +92,10 @@ def check_point_callback():
             if i == 0:
                 continue
             if as_pickle:
-                path = os.path.join(tmpdir, 'model_' + str(i) + '.pkl')
+                path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
             else:
-                path = os.path.join(tmpdir, 'model_' + str(i) + '.json')
-            assert(os.path.exists(path))
+                path = os.path.join(tmpdir, "model_" + str(i) + ".json")
+            assert os.path.exists(path)

     X, y = load_breast_cancer(return_X_y=True)
     m = xgb.DMatrix(X, y)
@@ -101,31 +103,36 @@ def check_point_callback():
     with tempfile.TemporaryDirectory() as tmpdir:
         # Use callback class from xgboost.callback
         # Feel free to subclass/customize it to suit your need.
-        check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
-                                                      iterations=rounds,
-                                                      name='model')
-        xgb.train({'objective': 'binary:logistic'}, m,
-                  num_boost_round=10,
-                  verbose_eval=False,
-                  callbacks=[check_point])
+        check_point = xgb.callback.TrainingCheckPoint(
+            directory=tmpdir, iterations=rounds, name="model"
+        )
+        xgb.train(
+            {"objective": "binary:logistic"},
+            m,
+            num_boost_round=10,
+            verbose_eval=False,
+            callbacks=[check_point],
+        )
         check(False)

         # This version of checkpoint saves everything including parameters and
         # model. See: doc/tutorials/saving_model.rst
-        check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
-                                                      iterations=rounds,
-                                                      as_pickle=True,
-                                                      name='model')
-        xgb.train({'objective': 'binary:logistic'}, m,
-                  num_boost_round=10,
-                  verbose_eval=False,
-                  callbacks=[check_point])
+        check_point = xgb.callback.TrainingCheckPoint(
+            directory=tmpdir, iterations=rounds, as_pickle=True, name="model"
+        )
+        xgb.train(
+            {"objective": "binary:logistic"},
+            m,
+            num_boost_round=10,
+            verbose_eval=False,
+            callbacks=[check_point],
+        )
         check(True)

-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--plot', default=1, type=int)
+    parser.add_argument("--plot", default=1, type=int)
     args = parser.parse_args()
     check_point_callback()

View File

@@ -37,3 +37,7 @@ The sliced model is a copy of selected trees, that means the model itself is imm
 during slicing. This feature is the basis of `save_best` option in early stopping
 callback. See :ref:`sphx_glr_python_examples_individual_trees.py` for a worked example on
 how to combine prediction with sliced trees.
+
+.. note::
+
+   The returned model slice doesn't contain attributes like :py:class:`~xgboost.Booster.best_iteration` and :py:class:`~xgboost.Booster.best_score`.
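
A minimal sketch of the documented behavior added above; the dataset and parameters are illustrative, only the slicing syntax and the two attributes come from this change:

import xgboost as xgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)
dtrain = xgb.DMatrix(X, y)
booster = xgb.train(
    {"objective": "binary:logistic", "eval_metric": "error"},
    dtrain,
    num_boost_round=16,
    evals=[(dtrain, "Train")],
    early_stopping_rounds=2,
)
# Slicing returns a copy of the selected trees ...
sliced = booster[: booster.best_iteration + 1]
# ... but the slice does not carry the early-stopping attributes.
try:
    sliced.best_iteration
except AttributeError:
    pass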

View File

@@ -134,13 +134,17 @@ class CallbackContainer:
         is_cv: bool = False,
     ) -> None:
         self.callbacks = set(callbacks)
-        if metric is not None:
-            msg = (
-                "metric must be callable object for monitoring. For "
-                + "builtin metrics, passing them in training parameter"
-                + " will invoke monitor automatically."
-            )
-            assert callable(metric), msg
+        for cb in callbacks:
+            if not isinstance(cb, TrainingCallback):
+                raise TypeError("callback must be an instance of `TrainingCallback`.")
+
+        msg = (
+            "metric must be callable object for monitoring. For builtin metrics"
+            ", passing them in training parameter invokes monitor automatically."
+        )
+        if metric is not None and not callable(metric):
+            raise TypeError(msg)
         self.metric = metric
         self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
         self._output_margin = output_margin
@@ -170,16 +174,6 @@ class CallbackContainer:
         else:
             assert isinstance(model, Booster), msg
-        if not self.is_cv:
-            if model.attr("best_score") is not None:
-                model.best_score = float(cast(str, model.attr("best_score")))
-                model.best_iteration = int(cast(str, model.attr("best_iteration")))
-            else:
-                # Due to compatibility with version older than 1.4, these attributes are
-                # added to Python object even if early stopping is not used.
-                model.best_iteration = model.num_boosted_rounds() - 1
-                model.set_attr(best_iteration=str(model.best_iteration))
         return model

     def before_iteration(
@@ -267,9 +261,14 @@ class LearningRateScheduler(TrainingCallback):
     def __init__(
         self, learning_rates: Union[Callable[[int], float], Sequence[float]]
     ) -> None:
-        assert callable(learning_rates) or isinstance(
+        if not callable(learning_rates) and not isinstance(
             learning_rates, collections.abc.Sequence
-        )
+        ):
+            raise TypeError(
+                "Invalid learning rates, expecting callable or sequence, got: "
+                f"{type(learning_rates)}"
+            )
         if callable(learning_rates):
             self.learning_rates = learning_rates
         else:
@@ -302,24 +301,28 @@ class EarlyStopping(TrainingCallback):
     save_best :
         Whether training should return the best model or the last model.
     min_delta :
-        Minimum absolute change in score to be qualified as an improvement.

         .. versionadded:: 1.5.0

-        .. code-block:: python
+        Minimum absolute change in score to be qualified as an improvement.
+
+    Examples
+    --------
+
+    .. code-block:: python

-            es = xgboost.callback.EarlyStopping(
-                rounds=2,
-                min_delta=1e-3,
-                save_best=True,
-                maximize=False,
-                data_name="validation_0",
-                metric_name="mlogloss",
-            )
-            clf = xgboost.XGBClassifier(tree_method="gpu_hist", callbacks=[es])
-            X, y = load_digits(return_X_y=True)
-            clf.fit(X, y, eval_set=[(X, y)])
+        es = xgboost.callback.EarlyStopping(
+            rounds=2,
+            min_delta=1e-3,
+            save_best=True,
+            maximize=False,
+            data_name="validation_0",
+            metric_name="mlogloss",
+        )
+        clf = xgboost.XGBClassifier(tree_method="hist", device="cuda", callbacks=[es])
+        X, y = load_digits(return_X_y=True)
+        clf.fit(X, y, eval_set=[(X, y)])
     """
     # pylint: disable=too-many-arguments
@@ -363,7 +366,7 @@ class EarlyStopping(TrainingCallback):
             return numpy.greater(get_s(new) - self._min_delta, get_s(best))

         def minimize(new: _Score, best: _Score) -> bool:
-            """New score should be smaller than the old one."""
+            """New score should be lesser than the old one."""
             return numpy.greater(get_s(best) - self._min_delta, get_s(new))

         if self.maximize is None:
@@ -419,38 +422,53 @@
     ) -> bool:
         epoch += self.starting_round  # training continuation
         msg = "Must have at least 1 validation dataset for early stopping."
-        assert len(evals_log.keys()) >= 1, msg
-        data_name = ""
+        if len(evals_log.keys()) < 1:
+            raise ValueError(msg)
+
+        # Get data name
         if self.data:
-            for d, _ in evals_log.items():
-                if d == self.data:
-                    data_name = d
-            if not data_name:
-                raise ValueError("No dataset named:", self.data)
+            data_name = self.data
         else:
             # Use the last one as default.
             data_name = list(evals_log.keys())[-1]
-        assert isinstance(data_name, str) and data_name
+        if data_name not in evals_log:
+            raise ValueError(f"No dataset named: {data_name}")
+
+        if not isinstance(data_name, str):
+            raise TypeError(
+                f"The name of the dataset should be a string. Got: {type(data_name)}"
+            )
         data_log = evals_log[data_name]

-        # Filter out scores that can not be used for early stopping.
+        # Get metric name
         if self.metric_name:
             metric_name = self.metric_name
         else:
             # Use last metric by default.
-            assert isinstance(data_log, collections.OrderedDict)
             metric_name = list(data_log.keys())[-1]
+        if metric_name not in data_log:
+            raise ValueError(f"No metric named: {metric_name}")

+        # The latest score
         score = data_log[metric_name][-1]
         return self._update_rounds(score, data_name, metric_name, model, epoch)

     def after_training(self, model: _Model) -> _Model:
+        if not self.save_best:
+            return model
+
         try:
-            if self.save_best:
-                model = model[: int(model.attr("best_iteration")) + 1]
+            best_iteration = model.best_iteration
+            best_score = model.best_score
+            assert best_iteration is not None and best_score is not None
+            model = model[: best_iteration + 1]
+            model.best_iteration = best_iteration
+            model.best_score = best_score
         except XGBoostError as e:
             raise XGBoostError(
-                "`save_best` is not applicable to current booster"
+                "`save_best` is not applicable to the current booster"
             ) from e
         return model
@@ -462,8 +480,6 @@ class EvaluationMonitor(TrainingCallback):
     Parameters
     ----------
-    metric :
-        Extra user defined metric.
     rank :
         Which worker should be used for printing the result.
     period :

View File

@@ -1890,7 +1890,7 @@ class Booster:
         attr_names = from_cstr_to_pystr(sarr, length)
         return {n: self.attr(n) for n in attr_names}

-    def set_attr(self, **kwargs: Optional[str]) -> None:
+    def set_attr(self, **kwargs: Optional[Any]) -> None:
         """Set the attribute of the Booster.

         Parameters
@@ -2559,10 +2559,35 @@
         else:
             raise TypeError("Unknown file type: ", fname)

-        if self.attr("best_iteration") is not None:
-            self.best_iteration = int(cast(int, self.attr("best_iteration")))
-        if self.attr("best_score") is not None:
-            self.best_score = float(cast(float, self.attr("best_score")))
+    @property
+    def best_iteration(self) -> int:
+        """The best iteration during training."""
+        best = self.attr("best_iteration")
+        if best is not None:
+            return int(best)
+        raise AttributeError(
+            "`best_iteration` is only defined when early stopping is used."
+        )
+
+    @best_iteration.setter
+    def best_iteration(self, iteration: int) -> None:
+        self.set_attr(best_iteration=iteration)
+
+    @property
+    def best_score(self) -> float:
+        """The best evaluation score during training."""
+        best = self.attr("best_score")
+        if best is not None:
+            return float(best)
+        raise AttributeError(
+            "`best_score` is only defined when early stopping is used."
+        )
+
+    @best_score.setter
+    def best_score(self, score: int) -> None:
+        self.set_attr(best_score=score)

     def num_boosted_rounds(self) -> int:
         """Get number of boosted rounds. For gblinear this is reset to 0 after

View File

@@ -230,10 +230,10 @@ __model_doc = f"""
     subsample : Optional[float]
         Subsample ratio of the training instance.
     sampling_method :
-        Sampling method. Used only by `gpu_hist` tree method.
-          - `uniform`: select random training instances uniformly.
-          - `gradient_based` select random training instances with higher probability when
-            the gradient and hessian are larger. (cf. CatBoost)
+        Sampling method. Used only by the GPU version of ``hist`` tree method.
+          - ``uniform``: select random training instances uniformly.
+          - ``gradient_based`` select random training instances with higher probability
+            when the gradient and hessian are larger. (cf. CatBoost)
     colsample_bytree : Optional[float]
         Subsample ratio of columns when constructing each tree.
     colsample_bylevel : Optional[float]
@@ -992,12 +992,12 @@ class XGBModel(XGBModelBase):
         X :
            Feature matrix. See :ref:`py-data` for a list of supported types.

-            When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
+            When the ``tree_method`` is set to ``hist``, internally, the
             :py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
             for conserving memory. However, this has performance implications when the
             device of input data is not matched with algorithm. For instance, if the
-            input is a numpy array on CPU but ``gpu_hist`` is used for training, then
-            the data is first processed on CPU then transferred to GPU.
+            input is a numpy array on CPU but ``cuda`` is used for training, then the
+            data is first processed on CPU then transferred to GPU.
         y :
             Labels
         sample_weight :
@@ -1279,19 +1279,10 @@ class XGBModel(XGBModelBase):
             )
         return np.array(feature_names)

-    def _early_stopping_attr(self, attr: str) -> Union[float, int]:
-        booster = self.get_booster()
-        try:
-            return getattr(booster, attr)
-        except AttributeError as e:
-            raise AttributeError(
-                f"`{attr}` in only defined when early stopping is used."
-            ) from e
-
     @property
     def best_score(self) -> float:
         """The best score obtained by early stopping."""
-        return float(self._early_stopping_attr("best_score"))
+        return self.get_booster().best_score

     @property
     def best_iteration(self) -> int:
@@ -1299,7 +1290,7 @@ class XGBModel(XGBModelBase):
         for instance if the best iteration is the first round, then best_iteration is 0.
         """
-        return int(self._early_stopping_attr("best_iteration"))
+        return self.get_booster().best_iteration

     @property
     def feature_importances_(self) -> np.ndarray:
@@ -1926,12 +1917,12 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
            | 1 | :math:`x_{20}` | :math:`x_{21}` |
            +-----+----------------+----------------+

-            When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
+            When the ``tree_method`` is set to ``hist``, internally, the
             :py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
             for conserving memory. However, this has performance implications when the
             device of input data is not matched with algorithm. For instance, if the
-            input is a numpy array on CPU but ``gpu_hist`` is used for training, then
-            the data is first processed on CPU then transferred to GPU.
+            input is a numpy array on CPU but ``cuda`` is used for training, then the
+            data is first processed on CPU then transferred to GPU.
         y :
             Labels
         group :

View File

@@ -28,17 +28,6 @@ from .core import (
 _CVFolds = Sequence["CVPack"]

-def _assert_new_callback(callbacks: Optional[Sequence[TrainingCallback]]) -> None:
-    is_new_callback: bool = not callbacks or all(
-        isinstance(c, TrainingCallback) for c in callbacks
-    )
-    if not is_new_callback:
-        link = "https://xgboost.readthedocs.io/en/latest/python/callbacks.html"
-        raise ValueError(
-            f"Old style callback was removed in version 1.6. See: {link}."
-        )
-
 def _configure_custom_metric(
     feval: Optional[Metric], custom_metric: Optional[Metric]
 ) -> Optional[Metric]:
@@ -170,7 +159,6 @@ def train(
     bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
     start_iteration = 0

-    _assert_new_callback(callbacks)
     if verbose_eval:
         verbose_eval = 1 if verbose_eval is True else verbose_eval
         callbacks.append(EvaluationMonitor(period=verbose_eval))
@@ -247,7 +235,7 @@ class _PackedBooster:
         result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
         return result

-    def set_attr(self, **kwargs: Optional[str]) -> Any:
+    def set_attr(self, **kwargs: Optional[Any]) -> Any:
         """Iterate through folds for setting attributes"""
         for f in self.cvfolds:
             f.bst.set_attr(**kwargs)
@@ -274,11 +262,20 @@ class _PackedBooster:
         """Get best_iteration"""
         return int(cast(int, self.cvfolds[0].bst.attr("best_iteration")))

+    @best_iteration.setter
+    def best_iteration(self, iteration: int) -> None:
+        """Get best_iteration"""
+        self.set_attr(best_iteration=iteration)
+
     @property
     def best_score(self) -> float:
         """Get best_score."""
         return float(cast(float, self.cvfolds[0].bst.attr("best_score")))
+
+    @best_score.setter
+    def best_score(self, score: float) -> None:
+        self.set_attr(best_score=score)

 def groups_to_rows(groups: List[np.ndarray], boundaries: np.ndarray) -> np.ndarray:
     """
@@ -551,7 +548,6 @@ def cv(
     # setup callbacks
     callbacks = [] if callbacks is None else copy.copy(list(callbacks))

-    _assert_new_callback(callbacks)
     if verbose_eval:
         verbose_eval = 1 if verbose_eval is True else verbose_eval

View File

@@ -37,6 +37,7 @@ class LintersPaths:
         "demo/rmm_plugin",
         "demo/json-model/json_parser.py",
         "demo/guide-python/cat_in_the_dat.py",
+        "demo/guide-python/callbacks.py",
         "demo/guide-python/categorical.py",
         "demo/guide-python/feature_weights.py",
         "demo/guide-python/sklearn_parallel.py",

View File

@@ -1,7 +1,6 @@
 import json
 import os
 import tempfile
-from contextlib import nullcontext
 from typing import Union

 import pytest
@@ -104,15 +103,6 @@ class TestCallbacks:
         dump = booster.get_dump(dump_format='json')
         assert len(dump) - booster.best_iteration == early_stopping_rounds + 1

-        # No early stopping, best_iteration should be set to last epoch
-        booster = xgb.train({'objective': 'binary:logistic',
-                             'eval_metric': 'error'}, D_train,
-                            evals=[(D_train, 'Train'), (D_valid, 'Valid')],
-                            num_boost_round=10,
-                            evals_result=evals_result,
-                            verbose_eval=True)
-        assert booster.num_boosted_rounds() - 1 == booster.best_iteration
-
     def test_early_stopping_custom_eval(self):
         D_train = xgb.DMatrix(self.X_train, self.y_train)
         D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
@@ -204,8 +194,9 @@ class TestCallbacks:
         X, y = load_breast_cancer(return_X_y=True)
         n_estimators = 100
         early_stopping_rounds = 5
-        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
-                                                save_best=True)
+        early_stop = xgb.callback.EarlyStopping(
+            rounds=early_stopping_rounds, save_best=True
+        )
         cls = xgb.XGBClassifier(
             n_estimators=n_estimators,
             eval_metric=tm.eval_error_metric_skl,
@@ -216,20 +207,27 @@ class TestCallbacks:
         dump = booster.get_dump(dump_format='json')
         assert len(dump) == booster.best_iteration + 1

-        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
-                                                save_best=True)
+        early_stop = xgb.callback.EarlyStopping(
+            rounds=early_stopping_rounds, save_best=True
+        )
         cls = xgb.XGBClassifier(
-            booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl
+            booster="gblinear",
+            n_estimators=10,
+            eval_metric=tm.eval_error_metric_skl,
+            callbacks=[early_stop],
         )
         with pytest.raises(ValueError):
-            cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
+            cls.fit(X, y, eval_set=[(X, y)])

         # No error
         early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                                 save_best=False)
         xgb.XGBClassifier(
-            booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl
-        ).fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
+            booster="gblinear",
+            n_estimators=10,
+            eval_metric=tm.eval_error_metric_skl,
+            callbacks=[early_stop],
+        ).fit(X, y, eval_set=[(X, y)])

     def test_early_stopping_continuation(self):
         from sklearn.datasets import load_breast_cancer
@@ -252,8 +250,11 @@ class TestCallbacks:
         cls.load_model(path)
         assert cls._Booster is not None
         early_stopping_rounds = 3
-        cls.set_params(eval_metric=tm.eval_error_metric_skl)
-        cls.fit(X, y, eval_set=[(X, y)], early_stopping_rounds=early_stopping_rounds)
+        cls.set_params(
+            eval_metric=tm.eval_error_metric_skl,
+            early_stopping_rounds=early_stopping_rounds,
+        )
+        cls.fit(X, y, eval_set=[(X, y)])
         booster = cls.get_booster()
         assert booster.num_boosted_rounds() == \
             booster.best_iteration + early_stopping_rounds + 1
@@ -280,20 +281,20 @@ class TestCallbacks:
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 4

-        warning_check = nullcontext()
-
         # learning_rates as a list
         # init eta with 0 to check whether learning_rates work
         param = {'max_depth': 2, 'eta': 0, 'verbosity': 0,
                  'objective': 'binary:logistic', 'eval_metric': 'error',
                  'tree_method': tree_method}
         evals_result = {}
-        with warning_check:
-            bst = xgb.train(param, dtrain, num_round, watchlist,
-                            callbacks=[scheduler([
-                                0.8, 0.7, 0.6, 0.5
-                            ])],
-                            evals_result=evals_result)
+        bst = xgb.train(
+            param,
+            dtrain,
+            num_round,
+            evals=watchlist,
+            callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
+            evals_result=evals_result,
+        )
         eval_errors_0 = list(map(float, evals_result['eval']['error']))
         assert isinstance(bst, xgb.core.Booster)
         # validation error should decrease, if eta > 0
@@ -304,11 +305,15 @@ class TestCallbacks:
                  'objective': 'binary:logistic', 'eval_metric': 'error',
                  'tree_method': tree_method}
         evals_result = {}
-        with warning_check:
-            bst = xgb.train(param, dtrain, num_round, watchlist,
-                            callbacks=[scheduler(
-                                [0.8, 0.7, 0.6, 0.5])],
-                            evals_result=evals_result)
+        bst = xgb.train(
+            param,
+            dtrain,
+            num_round,
+            evals=watchlist,
+            callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
+            evals_result=evals_result,
+        )
         eval_errors_1 = list(map(float, evals_result['eval']['error']))
         assert isinstance(bst, xgb.core.Booster)
         # validation error should decrease, if learning_rate > 0
@@ -320,12 +325,14 @@ class TestCallbacks:
             'eval_metric': 'error', 'tree_method': tree_method
         }
         evals_result = {}
-        with warning_check:
-            bst = xgb.train(param, dtrain, num_round, watchlist,
-                            callbacks=[scheduler(
-                                [0, 0, 0, 0]
-                            )],
-                            evals_result=evals_result)
+        bst = xgb.train(
+            param,
+            dtrain,
+            num_round,
+            evals=watchlist,
+            callbacks=[scheduler([0, 0, 0, 0])],
+            evals_result=evals_result,
+        )
         eval_errors_2 = list(map(float, evals_result['eval']['error']))
         assert isinstance(bst, xgb.core.Booster)
         # validation error should not decrease, if eta/learning_rate = 0
@@ -336,12 +343,14 @@ class TestCallbacks:
             return num_boost_round / (ithround + 1)

         evals_result = {}
-        with warning_check:
-            bst = xgb.train(param, dtrain, num_round, watchlist,
-                            callbacks=[
-                                scheduler(eta_decay)
-                            ],
-                            evals_result=evals_result)
+        bst = xgb.train(
+            param,
+            dtrain,
+            num_round,
+            evals=watchlist,
+            callbacks=[scheduler(eta_decay)],
+            evals_result=evals_result,
+        )
         eval_errors_3 = list(map(float, evals_result['eval']['error']))

         assert isinstance(bst, xgb.core.Booster)
@@ -351,8 +360,7 @@ class TestCallbacks:
         for i in range(1, len(eval_errors_0)):
             assert eval_errors_3[i] != eval_errors_2[i]

-        with warning_check:
-            xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
+        xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])

     def run_eta_decay_leaf_output(self, tree_method: str, objective: str) -> None:
         # check decay has effect on leaf output.
@@ -378,7 +386,7 @@ class TestCallbacks:
             param,
             dtrain,
             num_round,
-            watchlist,
+            evals=watchlist,
             callbacks=[scheduler(eta_decay_0)],
         )
@@ -391,7 +399,7 @@ class TestCallbacks:
             param,
             dtrain,
             num_round,
-            watchlist,
+            evals=watchlist,
             callbacks=[scheduler(eta_decay_1)],
         )
         bst_json0 = bst0.save_raw(raw_format="json")
@@ -474,3 +482,24 @@ class TestCallbacks:
             callbacks=callbacks,
         )
         assert len(callbacks) == 1
+
+    def test_attribute_error(self) -> None:
+        from sklearn.datasets import load_breast_cancer
+
+        X, y = load_breast_cancer(return_X_y=True)
+
+        clf = xgb.XGBClassifier(n_estimators=8)
+        clf.fit(X, y, eval_set=[(X, y)])
+
+        with pytest.raises(AttributeError, match="early stopping is used"):
+            clf.best_iteration
+
+        with pytest.raises(AttributeError, match="early stopping is used"):
+            clf.best_score
+
+        booster = clf.get_booster()
+        with pytest.raises(AttributeError, match="early stopping is used"):
+            booster.best_iteration
+
+        with pytest.raises(AttributeError, match="early stopping is used"):
+            booster.best_score

View File

@@ -173,7 +173,7 @@ class TestInplacePredict:
         np.testing.assert_allclose(predt_from_dmatrix, predt_from_array)

         with pytest.raises(ValueError):
-            booster.predict(test, iteration_range=(0, booster.best_iteration + 2))
+            booster.predict(test, iteration_range=(0, booster.num_boosted_rounds() + 2))

         default = booster.predict(test)
@@ -181,7 +181,7 @@
         np.testing.assert_allclose(range_full, default)

         range_full = booster.predict(
-            test, iteration_range=(0, booster.best_iteration + 1)
+            test, iteration_range=(0, booster.num_boosted_rounds())
         )
         np.testing.assert_allclose(range_full, default)

View File

@@ -100,8 +100,8 @@ class TestTrainingContinuation:
         res2 = mean_squared_error(
             y_2class,
             gbdt_04.predict(
-                dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
-            )
+                dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
+            ),
         )
         assert res1 == res2
@@ -112,7 +112,7 @@
         res2 = mean_squared_error(
             y_2class,
             gbdt_04.predict(
-                dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
+                dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
             )
         )
         assert res1 == res2
@@ -126,7 +126,7 @@
         res1 = gbdt_05.predict(dtrain_5class)
         res2 = gbdt_05.predict(
-            dtrain_5class, iteration_range=(0, gbdt_05.best_iteration + 1)
+            dtrain_5class, iteration_range=(0, gbdt_05.num_boosted_rounds())
         )
         np.testing.assert_almost_equal(res1, res2)
@@ -138,15 +138,16 @@
     @pytest.mark.skipif(**tm.no_sklearn())
     def test_training_continuation_updaters_json(self):
         # Picked up from R tests.
-        updaters = 'grow_colmaker,prune,refresh'
+        updaters = "grow_colmaker,prune,refresh"
         params = self.generate_parameters()
         for p in params:
-            p['updater'] = updaters
+            p["updater"] = updaters
         self.run_training_continuation(params[0], params[1], params[2])

     @pytest.mark.skipif(**tm.no_sklearn())
     def test_changed_parameter(self):
         from sklearn.datasets import load_breast_cancer
         X, y = load_breast_cancer(return_X_y=True)
         clf = xgb.XGBClassifier(n_estimators=2)
         clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")