Define best_iteration only if early stopping is used. (#9403)

* Define `best_iteration` only if early stopping is used.

This is the behavior specified in the documentation, but it was not honored by the actual code (a short sketch of the new behavior follows the change list below).

- Don't set the attributes if there's no early stopping.
- Clean up the code for callbacks, and replace assertions with proper exceptions.
- Assign the attributes when early stopping `save_best` is used.
- Turn the attributes into Python properties.
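
A minimal sketch of the resulting behavior. Hedged: the dataset, parameters, and round counts here are illustrative and not part of the change itself.

import xgboost as xgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)
dtrain = xgb.DMatrix(X_train, y_train)
dvalid = xgb.DMatrix(X_valid, y_valid)

# Without early stopping, accessing the property now raises AttributeError.
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=4)
try:
    booster.best_iteration
except AttributeError as e:
    print(e)  # `best_iteration` is only defined when early stopping is used.

# With early stopping, both properties are defined as documented.
booster = xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=32,
    evals=[(dvalid, "Valid")],
    early_stopping_rounds=2,
)
print(booster.best_iteration, booster.best_score)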

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
Jiaming Yuan 2023-07-24 12:43:35 +08:00 committed by GitHub
parent 01e00efc53
commit 851cba931e
10 changed files with 249 additions and 179 deletions


@ -1,9 +1,9 @@
'''
"""
Demo for using and defining callback functions
==============================================
.. versionadded:: 1.3.0
'''
"""
import argparse
import os
import tempfile
@ -17,10 +17,11 @@ import xgboost as xgb
class Plotting(xgb.callback.TrainingCallback):
'''Plot evaluation result during training. Only for demonstration purpose as it's quite
"""Plot evaluation result during training. Only for demonstration purpose as it's quite
slow to draw.
'''
"""
def __init__(self, rounds):
self.fig = plt.figure()
self.ax = self.fig.add_subplot(111)
@ -31,16 +32,16 @@ class Plotting(xgb.callback.TrainingCallback):
plt.ion()
def _get_key(self, data, metric):
return f'{data}-{metric}'
return f"{data}-{metric}"
def after_iteration(self, model, epoch, evals_log):
'''Update the plot.'''
"""Update the plot."""
if not self.lines:
for data, metric in evals_log.items():
for metric_name, log in metric.items():
key = self._get_key(data, metric_name)
expanded = log + [0] * (self.rounds - len(log))
self.lines[key], = self.ax.plot(self.x, expanded, label=key)
(self.lines[key],) = self.ax.plot(self.x, expanded, label=key)
self.ax.legend()
else:
# https://pythonspot.com/matplotlib-update-plot/
@ -55,8 +56,8 @@ class Plotting(xgb.callback.TrainingCallback):
def custom_callback():
'''Demo for defining a custom callback function that plots evaluation result during
training.'''
"""Demo for defining a custom callback function that plots evaluation result during
training."""
X, y = load_breast_cancer(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)
@ -69,15 +70,16 @@ def custom_callback():
# Pass it to the `callbacks` parameter as a list.
xgb.train(
{
'objective': 'binary:logistic',
'eval_metric': ['error', 'rmse'],
'tree_method': 'hist',
"objective": "binary:logistic",
"eval_metric": ["error", "rmse"],
"tree_method": "hist",
"device": "cuda",
},
D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
evals=[(D_train, "Train"), (D_valid, "Valid")],
num_boost_round=num_boost_round,
callbacks=[plotting])
callbacks=[plotting],
)
def check_point_callback():
@ -90,10 +92,10 @@ def check_point_callback():
if i == 0:
continue
if as_pickle:
path = os.path.join(tmpdir, 'model_' + str(i) + '.pkl')
path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
else:
path = os.path.join(tmpdir, 'model_' + str(i) + '.json')
assert(os.path.exists(path))
path = os.path.join(tmpdir, "model_" + str(i) + ".json")
assert os.path.exists(path)
X, y = load_breast_cancer(return_X_y=True)
m = xgb.DMatrix(X, y)
@ -101,31 +103,36 @@ def check_point_callback():
with tempfile.TemporaryDirectory() as tmpdir:
# Use callback class from xgboost.callback
# Feel free to subclass/customize it to suit your need.
check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
iterations=rounds,
name='model')
xgb.train({'objective': 'binary:logistic'}, m,
num_boost_round=10,
verbose_eval=False,
callbacks=[check_point])
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, name="model"
)
xgb.train(
{"objective": "binary:logistic"},
m,
num_boost_round=10,
verbose_eval=False,
callbacks=[check_point],
)
check(False)
# This version of checkpoint saves everything including parameters and
# model. See: doc/tutorials/saving_model.rst
check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
iterations=rounds,
as_pickle=True,
name='model')
xgb.train({'objective': 'binary:logistic'}, m,
num_boost_round=10,
verbose_eval=False,
callbacks=[check_point])
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, as_pickle=True, name="model"
)
xgb.train(
{"objective": "binary:logistic"},
m,
num_boost_round=10,
verbose_eval=False,
callbacks=[check_point],
)
check(True)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--plot', default=1, type=int)
parser.add_argument("--plot", default=1, type=int)
args = parser.parse_args()
check_point_callback()


@ -37,3 +37,7 @@ The sliced model is a copy of the selected trees, which means the model itself is immutable
during slicing. This feature is the basis of the `save_best` option in the early stopping
callback. See :ref:`sphx_glr_python_examples_individual_trees.py` for a worked example on
how to combine prediction with sliced trees.
.. note::
The returned model slice doesn't contain attributes like :py:attr:`~xgboost.Booster.best_iteration` and :py:attr:`~xgboost.Booster.best_score`.
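
For illustration, a sketch of the note above (a booster trained with early stopping is assumed; `booster[:n]` is the tree-slicing feature this section describes):

import xgboost as xgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)
dtrain = xgb.DMatrix(X, y)
booster = xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=32,
    evals=[(dtrain, "Train")],
    early_stopping_rounds=2,
)

sliced = booster[: booster.best_iteration + 1]
try:
    sliced.best_iteration  # the slice does not carry the attribute over
except AttributeError as e:
    print(e)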


@ -134,13 +134,17 @@ class CallbackContainer:
is_cv: bool = False,
) -> None:
self.callbacks = set(callbacks)
if metric is not None:
msg = (
"metric must be callable object for monitoring. For "
+ "builtin metrics, passing them in training parameter"
+ " will invoke monitor automatically."
)
assert callable(metric), msg
for cb in callbacks:
if not isinstance(cb, TrainingCallback):
raise TypeError("callback must be an instance of `TrainingCallback`.")
msg = (
"metric must be callable object for monitoring. For builtin metrics"
", passing them in training parameter invokes monitor automatically."
)
if metric is not None and not callable(metric):
raise TypeError(msg)
self.metric = metric
self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
self._output_margin = output_margin
@ -170,16 +174,6 @@ class CallbackContainer:
else:
assert isinstance(model, Booster), msg
if not self.is_cv:
if model.attr("best_score") is not None:
model.best_score = float(cast(str, model.attr("best_score")))
model.best_iteration = int(cast(str, model.attr("best_iteration")))
else:
# Due to compatibility with version older than 1.4, these attributes are
# added to Python object even if early stopping is not used.
model.best_iteration = model.num_boosted_rounds() - 1
model.set_attr(best_iteration=str(model.best_iteration))
return model
def before_iteration(
@ -267,9 +261,14 @@ class LearningRateScheduler(TrainingCallback):
def __init__(
self, learning_rates: Union[Callable[[int], float], Sequence[float]]
) -> None:
assert callable(learning_rates) or isinstance(
if not callable(learning_rates) and not isinstance(
learning_rates, collections.abc.Sequence
)
):
raise TypeError(
"Invalid learning rates, expecting callable or sequence, got: "
f"{type(learning_rates)}"
)
if callable(learning_rates):
self.learning_rates = learning_rates
else:
@ -302,24 +301,28 @@ class EarlyStopping(TrainingCallback):
save_best :
Whether training should return the best model or the last model.
min_delta :
Minimum absolute change in score to be qualified as an improvement.
.. versionadded:: 1.5.0
.. code-block:: python
Minimum absolute change in score to be qualified as an improvement.
es = xgboost.callback.EarlyStopping(
rounds=2,
min_delta=1e-3,
save_best=True,
maximize=False,
data_name="validation_0",
metric_name="mlogloss",
)
clf = xgboost.XGBClassifier(tree_method="gpu_hist", callbacks=[es])
Examples
--------
X, y = load_digits(return_X_y=True)
clf.fit(X, y, eval_set=[(X, y)])
.. code-block:: python
es = xgboost.callback.EarlyStopping(
rounds=2,
min_delta=1e-3,
save_best=True,
maximize=False,
data_name="validation_0",
metric_name="mlogloss",
)
clf = xgboost.XGBClassifier(tree_method="hist", device="cuda", callbacks=[es])
X, y = load_digits(return_X_y=True)
clf.fit(X, y, eval_set=[(X, y)])
"""
# pylint: disable=too-many-arguments
@ -363,7 +366,7 @@ class EarlyStopping(TrainingCallback):
return numpy.greater(get_s(new) - self._min_delta, get_s(best))
def minimize(new: _Score, best: _Score) -> bool:
"""New score should be smaller than the old one."""
"""New score should be lesser than the old one."""
return numpy.greater(get_s(best) - self._min_delta, get_s(new))
if self.maximize is None:
@ -419,38 +422,53 @@ class EarlyStopping(TrainingCallback):
) -> bool:
epoch += self.starting_round # training continuation
msg = "Must have at least 1 validation dataset for early stopping."
assert len(evals_log.keys()) >= 1, msg
data_name = ""
if len(evals_log.keys()) < 1:
raise ValueError(msg)
# Get data name
if self.data:
for d, _ in evals_log.items():
if d == self.data:
data_name = d
if not data_name:
raise ValueError("No dataset named:", self.data)
data_name = self.data
else:
# Use the last one as default.
data_name = list(evals_log.keys())[-1]
assert isinstance(data_name, str) and data_name
if data_name not in evals_log:
raise ValueError(f"No dataset named: {data_name}")
if not isinstance(data_name, str):
raise TypeError(
f"The name of the dataset should be a string. Got: {type(data_name)}"
)
data_log = evals_log[data_name]
# Filter out scores that can not be used for early stopping.
# Get metric name
if self.metric_name:
metric_name = self.metric_name
else:
# Use last metric by default.
assert isinstance(data_log, collections.OrderedDict)
metric_name = list(data_log.keys())[-1]
if metric_name not in data_log:
raise ValueError(f"No metric named: {metric_name}")
# The latest score
score = data_log[metric_name][-1]
return self._update_rounds(score, data_name, metric_name, model, epoch)
def after_training(self, model: _Model) -> _Model:
if not self.save_best:
return model
try:
if self.save_best:
model = model[: int(model.attr("best_iteration")) + 1]
best_iteration = model.best_iteration
best_score = model.best_score
assert best_iteration is not None and best_score is not None
model = model[: best_iteration + 1]
model.best_iteration = best_iteration
model.best_score = best_score
except XGBoostError as e:
raise XGBoostError(
"`save_best` is not applicable to current booster"
"`save_best` is not applicable to the current booster"
) from e
return model
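
A hedged sketch of `after_training` with `save_best` in action; it mirrors the scikit-learn test further below, and the dataset and estimator settings are illustrative:

import xgboost as xgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)
es = xgb.callback.EarlyStopping(rounds=3, save_best=True)
clf = xgb.XGBClassifier(n_estimators=100, callbacks=[es])
clf.fit(X, y, eval_set=[(X, y)])

booster = clf.get_booster()
# `after_training` sliced the model to the best iteration and re-assigned
# `best_iteration` and `best_score` onto the slice, so the counts line up.
assert booster.num_boosted_rounds() == booster.best_iteration + 1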
@ -462,8 +480,6 @@ class EvaluationMonitor(TrainingCallback):
Parameters
----------
metric :
Extra user defined metric.
rank :
Which worker should be used for printing the result.
period :


@ -1890,7 +1890,7 @@ class Booster:
attr_names = from_cstr_to_pystr(sarr, length)
return {n: self.attr(n) for n in attr_names}
def set_attr(self, **kwargs: Optional[str]) -> None:
def set_attr(self, **kwargs: Optional[Any]) -> None:
"""Set the attribute of the Booster.
Parameters
@ -2559,10 +2559,35 @@ class Booster:
else:
raise TypeError("Unknown file type: ", fname)
if self.attr("best_iteration") is not None:
self.best_iteration = int(cast(int, self.attr("best_iteration")))
if self.attr("best_score") is not None:
self.best_score = float(cast(float, self.attr("best_score")))
@property
def best_iteration(self) -> int:
"""The best iteration during training."""
best = self.attr("best_iteration")
if best is not None:
return int(best)
raise AttributeError(
"`best_iteration` is only defined when early stopping is used."
)
@best_iteration.setter
def best_iteration(self, iteration: int) -> None:
self.set_attr(best_iteration=iteration)
@property
def best_score(self) -> float:
"""The best evaluation score during training."""
best = self.attr("best_score")
if best is not None:
return float(best)
raise AttributeError(
"`best_score` is only defined when early stopping is used."
)
@best_score.setter
def best_score(self, score: int) -> None:
self.set_attr(best_score=score)
def num_boosted_rounds(self) -> int:
"""Get number of boosted rounds. For gblinear this is reset to 0 after

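A small round-trip sketch of the property pattern above: the setters persist values as Booster string attributes through `set_attr`, and the getters parse them back (the random training data is purely illustrative):

import numpy as np
import xgboost as xgb

booster = xgb.train(
    {"objective": "reg:squarederror"},
    xgb.DMatrix(np.random.rand(16, 4), np.random.rand(16)),
    num_boost_round=2,
)
booster.best_iteration = 1  # the setter stores the value as a string attribute
assert booster.attr("best_iteration") == "1"
assert booster.best_iteration == 1  # the getter parses it back to an int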

@ -230,10 +230,10 @@ __model_doc = f"""
subsample : Optional[float]
Subsample ratio of the training instance.
sampling_method :
Sampling method. Used only by `gpu_hist` tree method.
- `uniform`: select random training instances uniformly.
- `gradient_based` select random training instances with higher probability when
the gradient and hessian are larger. (cf. CatBoost)
Sampling method. Used only by the GPU version of ``hist`` tree method.
- ``uniform``: select random training instances uniformly.
- ``gradient_based`` select random training instances with higher probability
when the gradient and hessian are larger. (cf. CatBoost)
colsample_bytree : Optional[float]
Subsample ratio of columns when constructing each tree.
colsample_bylevel : Optional[float]
@ -992,12 +992,12 @@ class XGBModel(XGBModelBase):
X :
Feature matrix. See :ref:`py-data` for a list of supported types.
When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
When the ``tree_method`` is set to ``hist``, internally, the
:py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
for conserving memory. However, this has performance implications when the
device of input data is not matched with algorithm. For instance, if the
input is a numpy array on CPU but ``gpu_hist`` is used for training, then
the data is first processed on CPU then transferred to GPU.
input is a numpy array on CPU but ``cuda`` is used for training, then the
data is first processed on CPU then transferred to GPU.
y :
Labels
sample_weight :
@ -1279,19 +1279,10 @@ class XGBModel(XGBModelBase):
)
return np.array(feature_names)
def _early_stopping_attr(self, attr: str) -> Union[float, int]:
booster = self.get_booster()
try:
return getattr(booster, attr)
except AttributeError as e:
raise AttributeError(
f"`{attr}` in only defined when early stopping is used."
) from e
@property
def best_score(self) -> float:
"""The best score obtained by early stopping."""
return float(self._early_stopping_attr("best_score"))
return self.get_booster().best_score
@property
def best_iteration(self) -> int:
@ -1299,7 +1290,7 @@ class XGBModel(XGBModelBase):
for instance if the best iteration is the first round, then best_iteration is 0.
"""
return int(self._early_stopping_attr("best_iteration"))
return self.get_booster().best_iteration
@property
def feature_importances_(self) -> np.ndarray:
@ -1926,12 +1917,12 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
| 1 | :math:`x_{20}` | :math:`x_{21}` |
+-----+----------------+----------------+
When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
When the ``tree_method`` is set to ``hist``, internally, the
:py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
for conserving memory. However, this has performance implications when the
device of input data is not matched with algorithm. For instance, if the
input is a numpy array on CPU but ``gpu_hist`` is used for training, then
the data is first processed on CPU then transferred to GPU.
input is a numpy array on CPU but ``cuda`` is used for training, then the
data is first processed on CPU then transferred to GPU.
y :
Labels
group :


@ -28,17 +28,6 @@ from .core import (
_CVFolds = Sequence["CVPack"]
def _assert_new_callback(callbacks: Optional[Sequence[TrainingCallback]]) -> None:
is_new_callback: bool = not callbacks or all(
isinstance(c, TrainingCallback) for c in callbacks
)
if not is_new_callback:
link = "https://xgboost.readthedocs.io/en/latest/python/callbacks.html"
raise ValueError(
f"Old style callback was removed in version 1.6. See: {link}."
)
def _configure_custom_metric(
feval: Optional[Metric], custom_metric: Optional[Metric]
) -> Optional[Metric]:
@ -170,7 +159,6 @@ def train(
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
start_iteration = 0
_assert_new_callback(callbacks)
if verbose_eval:
verbose_eval = 1 if verbose_eval is True else verbose_eval
callbacks.append(EvaluationMonitor(period=verbose_eval))
@ -247,7 +235,7 @@ class _PackedBooster:
result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
return result
def set_attr(self, **kwargs: Optional[str]) -> Any:
def set_attr(self, **kwargs: Optional[Any]) -> Any:
"""Iterate through folds for setting attributes"""
for f in self.cvfolds:
f.bst.set_attr(**kwargs)
@ -274,11 +262,20 @@ class _PackedBooster:
"""Get best_iteration"""
return int(cast(int, self.cvfolds[0].bst.attr("best_iteration")))
@best_iteration.setter
def best_iteration(self, iteration: int) -> None:
"""Get best_iteration"""
self.set_attr(best_iteration=iteration)
@property
def best_score(self) -> float:
"""Get best_score."""
return float(cast(float, self.cvfolds[0].bst.attr("best_score")))
@best_score.setter
def best_score(self, score: float) -> None:
self.set_attr(best_score=score)
def groups_to_rows(groups: List[np.ndarray], boundaries: np.ndarray) -> np.ndarray:
"""
@ -551,7 +548,6 @@ def cv(
# setup callbacks
callbacks = [] if callbacks is None else copy.copy(list(callbacks))
_assert_new_callback(callbacks)
if verbose_eval:
verbose_eval = 1 if verbose_eval is True else verbose_eval


@ -37,6 +37,7 @@ class LintersPaths:
"demo/rmm_plugin",
"demo/json-model/json_parser.py",
"demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/callbacks.py",
"demo/guide-python/categorical.py",
"demo/guide-python/feature_weights.py",
"demo/guide-python/sklearn_parallel.py",


@ -1,7 +1,6 @@
import json
import os
import tempfile
from contextlib import nullcontext
from typing import Union
import pytest
@ -104,15 +103,6 @@ class TestCallbacks:
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
# No early stopping, best_iteration should be set to last epoch
booster = xgb.train({'objective': 'binary:logistic',
'eval_metric': 'error'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
num_boost_round=10,
evals_result=evals_result,
verbose_eval=True)
assert booster.num_boosted_rounds() - 1 == booster.best_iteration
def test_early_stopping_custom_eval(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
@ -204,8 +194,9 @@ class TestCallbacks:
X, y = load_breast_cancer(return_X_y=True)
n_estimators = 100
early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds, save_best=True
)
cls = xgb.XGBClassifier(
n_estimators=n_estimators,
eval_metric=tm.eval_error_metric_skl,
@ -216,20 +207,27 @@ class TestCallbacks:
dump = booster.get_dump(dump_format='json')
assert len(dump) == booster.best_iteration + 1
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds, save_best=True
)
cls = xgb.XGBClassifier(
booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl
booster="gblinear",
n_estimators=10,
eval_metric=tm.eval_error_metric_skl,
callbacks=[early_stop],
)
with pytest.raises(ValueError):
cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
cls.fit(X, y, eval_set=[(X, y)])
# No error
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=False)
xgb.XGBClassifier(
booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl
).fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
booster="gblinear",
n_estimators=10,
eval_metric=tm.eval_error_metric_skl,
callbacks=[early_stop],
).fit(X, y, eval_set=[(X, y)])
def test_early_stopping_continuation(self):
from sklearn.datasets import load_breast_cancer
@ -252,8 +250,11 @@ class TestCallbacks:
cls.load_model(path)
assert cls._Booster is not None
early_stopping_rounds = 3
cls.set_params(eval_metric=tm.eval_error_metric_skl)
cls.fit(X, y, eval_set=[(X, y)], early_stopping_rounds=early_stopping_rounds)
cls.set_params(
eval_metric=tm.eval_error_metric_skl,
early_stopping_rounds=early_stopping_rounds,
)
cls.fit(X, y, eval_set=[(X, y)])
booster = cls.get_booster()
assert booster.num_boosted_rounds() == \
booster.best_iteration + early_stopping_rounds + 1
@ -280,20 +281,20 @@ class TestCallbacks:
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 4
warning_check = nullcontext()
# learning_rates as a list
# init eta with 0 to check whether learning_rates work
param = {'max_depth': 2, 'eta': 0, 'verbosity': 0,
'objective': 'binary:logistic', 'eval_metric': 'error',
'tree_method': tree_method}
evals_result = {}
with warning_check:
bst = xgb.train(param, dtrain, num_round, watchlist,
callbacks=[scheduler([
0.8, 0.7, 0.6, 0.5
])],
evals_result=evals_result)
bst = xgb.train(
param,
dtrain,
num_round,
evals=watchlist,
callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
evals_result=evals_result,
)
eval_errors_0 = list(map(float, evals_result['eval']['error']))
assert isinstance(bst, xgb.core.Booster)
# validation error should decrease, if eta > 0
@ -304,11 +305,15 @@ class TestCallbacks:
'objective': 'binary:logistic', 'eval_metric': 'error',
'tree_method': tree_method}
evals_result = {}
with warning_check:
bst = xgb.train(param, dtrain, num_round, watchlist,
callbacks=[scheduler(
[0.8, 0.7, 0.6, 0.5])],
evals_result=evals_result)
bst = xgb.train(
param,
dtrain,
num_round,
evals=watchlist,
callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
evals_result=evals_result,
)
eval_errors_1 = list(map(float, evals_result['eval']['error']))
assert isinstance(bst, xgb.core.Booster)
# validation error should decrease, if learning_rate > 0
@ -320,12 +325,14 @@ class TestCallbacks:
'eval_metric': 'error', 'tree_method': tree_method
}
evals_result = {}
with warning_check:
bst = xgb.train(param, dtrain, num_round, watchlist,
callbacks=[scheduler(
[0, 0, 0, 0]
)],
evals_result=evals_result)
bst = xgb.train(
param,
dtrain,
num_round,
evals=watchlist,
callbacks=[scheduler([0, 0, 0, 0])],
evals_result=evals_result,
)
eval_errors_2 = list(map(float, evals_result['eval']['error']))
assert isinstance(bst, xgb.core.Booster)
# validation error should not decrease, if eta/learning_rate = 0
@ -336,12 +343,14 @@ class TestCallbacks:
return num_boost_round / (ithround + 1)
evals_result = {}
with warning_check:
bst = xgb.train(param, dtrain, num_round, watchlist,
callbacks=[
scheduler(eta_decay)
],
evals_result=evals_result)
bst = xgb.train(
param,
dtrain,
num_round,
evals=watchlist,
callbacks=[scheduler(eta_decay)],
evals_result=evals_result,
)
eval_errors_3 = list(map(float, evals_result['eval']['error']))
assert isinstance(bst, xgb.core.Booster)
@ -351,8 +360,7 @@ class TestCallbacks:
for i in range(1, len(eval_errors_0)):
assert eval_errors_3[i] != eval_errors_2[i]
with warning_check:
xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
def run_eta_decay_leaf_output(self, tree_method: str, objective: str) -> None:
# check decay has effect on leaf output.
@ -378,7 +386,7 @@ class TestCallbacks:
param,
dtrain,
num_round,
watchlist,
evals=watchlist,
callbacks=[scheduler(eta_decay_0)],
)
@ -391,7 +399,7 @@ class TestCallbacks:
param,
dtrain,
num_round,
watchlist,
evals=watchlist,
callbacks=[scheduler(eta_decay_1)],
)
bst_json0 = bst0.save_raw(raw_format="json")
@ -474,3 +482,24 @@ class TestCallbacks:
callbacks=callbacks,
)
assert len(callbacks) == 1
def test_attribute_error(self) -> None:
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
clf = xgb.XGBClassifier(n_estimators=8)
clf.fit(X, y, eval_set=[(X, y)])
with pytest.raises(AttributeError, match="early stopping is used"):
clf.best_iteration
with pytest.raises(AttributeError, match="early stopping is used"):
clf.best_score
booster = clf.get_booster()
with pytest.raises(AttributeError, match="early stopping is used"):
booster.best_iteration
with pytest.raises(AttributeError, match="early stopping is used"):
booster.best_score


@ -173,7 +173,7 @@ class TestInplacePredict:
np.testing.assert_allclose(predt_from_dmatrix, predt_from_array)
with pytest.raises(ValueError):
booster.predict(test, iteration_range=(0, booster.best_iteration + 2))
booster.predict(test, iteration_range=(0, booster.num_boosted_rounds() + 2))
default = booster.predict(test)
@ -181,7 +181,7 @@ class TestInplacePredict:
np.testing.assert_allclose(range_full, default)
range_full = booster.predict(
test, iteration_range=(0, booster.best_iteration + 1)
test, iteration_range=(0, booster.num_boosted_rounds())
)
np.testing.assert_allclose(range_full, default)


@ -100,8 +100,8 @@ class TestTrainingContinuation:
res2 = mean_squared_error(
y_2class,
gbdt_04.predict(
dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
)
dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
),
)
assert res1 == res2
@ -112,7 +112,7 @@ class TestTrainingContinuation:
res2 = mean_squared_error(
y_2class,
gbdt_04.predict(
dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
)
)
assert res1 == res2
@ -126,7 +126,7 @@ class TestTrainingContinuation:
res1 = gbdt_05.predict(dtrain_5class)
res2 = gbdt_05.predict(
dtrain_5class, iteration_range=(0, gbdt_05.best_iteration + 1)
dtrain_5class, iteration_range=(0, gbdt_05.num_boosted_rounds())
)
np.testing.assert_almost_equal(res1, res2)
@ -138,15 +138,16 @@ class TestTrainingContinuation:
@pytest.mark.skipif(**tm.no_sklearn())
def test_training_continuation_updaters_json(self):
# Picked up from R tests.
updaters = 'grow_colmaker,prune,refresh'
updaters = "grow_colmaker,prune,refresh"
params = self.generate_parameters()
for p in params:
p['updater'] = updaters
p["updater"] = updaters
self.run_training_continuation(params[0], params[1], params[2])
@pytest.mark.skipif(**tm.no_sklearn())
def test_changed_parameter(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
clf = xgb.XGBClassifier(n_estimators=2)
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")