Define best_iteration only if early stopping is used. (#9403)
* Define `best_iteration` only if early stopping is used. This is the behavior specified by the documentation but not honored in the actual code. - Don't set the attributes if there's no early stopping. - Clean up the code for callbacks, and replace assertions with proper exceptions. - Assign the attributes when early stopping `save_best` is used. - Turn the attributes into Python properties. --------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
01e00efc53
commit
851cba931e
@ -1,9 +1,9 @@
|
|||||||
'''
|
"""
|
||||||
Demo for using and defining callback functions
|
Demo for using and defining callback functions
|
||||||
==============================================
|
==============================================
|
||||||
|
|
||||||
.. versionadded:: 1.3.0
|
.. versionadded:: 1.3.0
|
||||||
'''
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
@ -17,10 +17,11 @@ import xgboost as xgb
|
|||||||
|
|
||||||
|
|
||||||
class Plotting(xgb.callback.TrainingCallback):
|
class Plotting(xgb.callback.TrainingCallback):
|
||||||
'''Plot evaluation result during training. Only for demonstration purpose as it's quite
|
"""Plot evaluation result during training. Only for demonstration purpose as it's quite
|
||||||
slow to draw.
|
slow to draw.
|
||||||
|
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, rounds):
|
def __init__(self, rounds):
|
||||||
self.fig = plt.figure()
|
self.fig = plt.figure()
|
||||||
self.ax = self.fig.add_subplot(111)
|
self.ax = self.fig.add_subplot(111)
|
||||||
@ -31,16 +32,16 @@ class Plotting(xgb.callback.TrainingCallback):
|
|||||||
plt.ion()
|
plt.ion()
|
||||||
|
|
||||||
def _get_key(self, data, metric):
|
def _get_key(self, data, metric):
|
||||||
return f'{data}-{metric}'
|
return f"{data}-{metric}"
|
||||||
|
|
||||||
def after_iteration(self, model, epoch, evals_log):
|
def after_iteration(self, model, epoch, evals_log):
|
||||||
'''Update the plot.'''
|
"""Update the plot."""
|
||||||
if not self.lines:
|
if not self.lines:
|
||||||
for data, metric in evals_log.items():
|
for data, metric in evals_log.items():
|
||||||
for metric_name, log in metric.items():
|
for metric_name, log in metric.items():
|
||||||
key = self._get_key(data, metric_name)
|
key = self._get_key(data, metric_name)
|
||||||
expanded = log + [0] * (self.rounds - len(log))
|
expanded = log + [0] * (self.rounds - len(log))
|
||||||
self.lines[key], = self.ax.plot(self.x, expanded, label=key)
|
(self.lines[key],) = self.ax.plot(self.x, expanded, label=key)
|
||||||
self.ax.legend()
|
self.ax.legend()
|
||||||
else:
|
else:
|
||||||
# https://pythonspot.com/matplotlib-update-plot/
|
# https://pythonspot.com/matplotlib-update-plot/
|
||||||
@ -55,8 +56,8 @@ class Plotting(xgb.callback.TrainingCallback):
|
|||||||
|
|
||||||
|
|
||||||
def custom_callback():
|
def custom_callback():
|
||||||
'''Demo for defining a custom callback function that plots evaluation result during
|
"""Demo for defining a custom callback function that plots evaluation result during
|
||||||
training.'''
|
training."""
|
||||||
X, y = load_breast_cancer(return_X_y=True)
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)
|
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)
|
||||||
|
|
||||||
@ -69,15 +70,16 @@ def custom_callback():
|
|||||||
# Pass it to the `callbacks` parameter as a list.
|
# Pass it to the `callbacks` parameter as a list.
|
||||||
xgb.train(
|
xgb.train(
|
||||||
{
|
{
|
||||||
'objective': 'binary:logistic',
|
"objective": "binary:logistic",
|
||||||
'eval_metric': ['error', 'rmse'],
|
"eval_metric": ["error", "rmse"],
|
||||||
'tree_method': 'hist',
|
"tree_method": "hist",
|
||||||
"device": "cuda",
|
"device": "cuda",
|
||||||
},
|
},
|
||||||
D_train,
|
D_train,
|
||||||
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
evals=[(D_train, "Train"), (D_valid, "Valid")],
|
||||||
num_boost_round=num_boost_round,
|
num_boost_round=num_boost_round,
|
||||||
callbacks=[plotting])
|
callbacks=[plotting],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_point_callback():
|
def check_point_callback():
|
||||||
@ -90,10 +92,10 @@ def check_point_callback():
|
|||||||
if i == 0:
|
if i == 0:
|
||||||
continue
|
continue
|
||||||
if as_pickle:
|
if as_pickle:
|
||||||
path = os.path.join(tmpdir, 'model_' + str(i) + '.pkl')
|
path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
|
||||||
else:
|
else:
|
||||||
path = os.path.join(tmpdir, 'model_' + str(i) + '.json')
|
path = os.path.join(tmpdir, "model_" + str(i) + ".json")
|
||||||
assert(os.path.exists(path))
|
assert os.path.exists(path)
|
||||||
|
|
||||||
X, y = load_breast_cancer(return_X_y=True)
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
m = xgb.DMatrix(X, y)
|
m = xgb.DMatrix(X, y)
|
||||||
@ -101,31 +103,36 @@ def check_point_callback():
|
|||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
# Use callback class from xgboost.callback
|
# Use callback class from xgboost.callback
|
||||||
# Feel free to subclass/customize it to suit your need.
|
# Feel free to subclass/customize it to suit your need.
|
||||||
check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
|
check_point = xgb.callback.TrainingCheckPoint(
|
||||||
iterations=rounds,
|
directory=tmpdir, iterations=rounds, name="model"
|
||||||
name='model')
|
)
|
||||||
xgb.train({'objective': 'binary:logistic'}, m,
|
xgb.train(
|
||||||
num_boost_round=10,
|
{"objective": "binary:logistic"},
|
||||||
verbose_eval=False,
|
m,
|
||||||
callbacks=[check_point])
|
num_boost_round=10,
|
||||||
|
verbose_eval=False,
|
||||||
|
callbacks=[check_point],
|
||||||
|
)
|
||||||
check(False)
|
check(False)
|
||||||
|
|
||||||
# This version of checkpoint saves everything including parameters and
|
# This version of checkpoint saves everything including parameters and
|
||||||
# model. See: doc/tutorials/saving_model.rst
|
# model. See: doc/tutorials/saving_model.rst
|
||||||
check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
|
check_point = xgb.callback.TrainingCheckPoint(
|
||||||
iterations=rounds,
|
directory=tmpdir, iterations=rounds, as_pickle=True, name="model"
|
||||||
as_pickle=True,
|
)
|
||||||
name='model')
|
xgb.train(
|
||||||
xgb.train({'objective': 'binary:logistic'}, m,
|
{"objective": "binary:logistic"},
|
||||||
num_boost_round=10,
|
m,
|
||||||
verbose_eval=False,
|
num_boost_round=10,
|
||||||
callbacks=[check_point])
|
verbose_eval=False,
|
||||||
|
callbacks=[check_point],
|
||||||
|
)
|
||||||
check(True)
|
check(True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--plot', default=1, type=int)
|
parser.add_argument("--plot", default=1, type=int)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
check_point_callback()
|
check_point_callback()
|
||||||
|
|||||||
@ -37,3 +37,7 @@ The sliced model is a copy of selected trees, that means the model itself is imm
|
|||||||
during slicing. This feature is the basis of `save_best` option in early stopping
|
during slicing. This feature is the basis of `save_best` option in early stopping
|
||||||
callback. See :ref:`sphx_glr_python_examples_individual_trees.py` for a worked example on
|
callback. See :ref:`sphx_glr_python_examples_individual_trees.py` for a worked example on
|
||||||
how to combine prediction with sliced trees.
|
how to combine prediction with sliced trees.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
The returned model slice doesn't contain attributes like :py:attr:`~xgboost.Booster.best_iteration` and :py:attr:`~xgboost.Booster.best_score`.
|
||||||
|
|||||||
@ -134,13 +134,17 @@ class CallbackContainer:
|
|||||||
is_cv: bool = False,
|
is_cv: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.callbacks = set(callbacks)
|
self.callbacks = set(callbacks)
|
||||||
if metric is not None:
|
for cb in callbacks:
|
||||||
msg = (
|
if not isinstance(cb, TrainingCallback):
|
||||||
"metric must be callable object for monitoring. For "
|
raise TypeError("callback must be an instance of `TrainingCallback`.")
|
||||||
+ "builtin metrics, passing them in training parameter"
|
|
||||||
+ " will invoke monitor automatically."
|
msg = (
|
||||||
)
|
"metric must be callable object for monitoring. For builtin metrics"
|
||||||
assert callable(metric), msg
|
", passing them in training parameter invokes monitor automatically."
|
||||||
|
)
|
||||||
|
if metric is not None and not callable(metric):
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
|
self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
|
||||||
self._output_margin = output_margin
|
self._output_margin = output_margin
|
||||||
@ -170,16 +174,6 @@ class CallbackContainer:
|
|||||||
else:
|
else:
|
||||||
assert isinstance(model, Booster), msg
|
assert isinstance(model, Booster), msg
|
||||||
|
|
||||||
if not self.is_cv:
|
|
||||||
if model.attr("best_score") is not None:
|
|
||||||
model.best_score = float(cast(str, model.attr("best_score")))
|
|
||||||
model.best_iteration = int(cast(str, model.attr("best_iteration")))
|
|
||||||
else:
|
|
||||||
# Due to compatibility with version older than 1.4, these attributes are
|
|
||||||
# added to Python object even if early stopping is not used.
|
|
||||||
model.best_iteration = model.num_boosted_rounds() - 1
|
|
||||||
model.set_attr(best_iteration=str(model.best_iteration))
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def before_iteration(
|
def before_iteration(
|
||||||
@ -267,9 +261,14 @@ class LearningRateScheduler(TrainingCallback):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, learning_rates: Union[Callable[[int], float], Sequence[float]]
|
self, learning_rates: Union[Callable[[int], float], Sequence[float]]
|
||||||
) -> None:
|
) -> None:
|
||||||
assert callable(learning_rates) or isinstance(
|
if not callable(learning_rates) and not isinstance(
|
||||||
learning_rates, collections.abc.Sequence
|
learning_rates, collections.abc.Sequence
|
||||||
)
|
):
|
||||||
|
raise TypeError(
|
||||||
|
"Invalid learning rates, expecting callable or sequence, got: "
|
||||||
|
f"{type(learning_rates)}"
|
||||||
|
)
|
||||||
|
|
||||||
if callable(learning_rates):
|
if callable(learning_rates):
|
||||||
self.learning_rates = learning_rates
|
self.learning_rates = learning_rates
|
||||||
else:
|
else:
|
||||||
@ -302,24 +301,28 @@ class EarlyStopping(TrainingCallback):
|
|||||||
save_best :
|
save_best :
|
||||||
Whether training should return the best model or the last model.
|
Whether training should return the best model or the last model.
|
||||||
min_delta :
|
min_delta :
|
||||||
Minimum absolute change in score to be qualified as an improvement.
|
|
||||||
|
|
||||||
.. versionadded:: 1.5.0
|
.. versionadded:: 1.5.0
|
||||||
|
|
||||||
.. code-block:: python
|
Minimum absolute change in score to be qualified as an improvement.
|
||||||
|
|
||||||
es = xgboost.callback.EarlyStopping(
|
Examples
|
||||||
rounds=2,
|
--------
|
||||||
min_delta=1e-3,
|
|
||||||
save_best=True,
|
|
||||||
maximize=False,
|
|
||||||
data_name="validation_0",
|
|
||||||
metric_name="mlogloss",
|
|
||||||
)
|
|
||||||
clf = xgboost.XGBClassifier(tree_method="gpu_hist", callbacks=[es])
|
|
||||||
|
|
||||||
X, y = load_digits(return_X_y=True)
|
.. code-block:: python
|
||||||
clf.fit(X, y, eval_set=[(X, y)])
|
|
||||||
|
es = xgboost.callback.EarlyStopping(
|
||||||
|
rounds=2,
|
||||||
|
min_delta=1e-3,
|
||||||
|
save_best=True,
|
||||||
|
maximize=False,
|
||||||
|
data_name="validation_0",
|
||||||
|
metric_name="mlogloss",
|
||||||
|
)
|
||||||
|
clf = xgboost.XGBClassifier(tree_method="hist", device="cuda", callbacks=[es])
|
||||||
|
|
||||||
|
X, y = load_digits(return_X_y=True)
|
||||||
|
clf.fit(X, y, eval_set=[(X, y)])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=too-many-arguments
|
# pylint: disable=too-many-arguments
|
||||||
@ -363,7 +366,7 @@ class EarlyStopping(TrainingCallback):
|
|||||||
return numpy.greater(get_s(new) - self._min_delta, get_s(best))
|
return numpy.greater(get_s(new) - self._min_delta, get_s(best))
|
||||||
|
|
||||||
def minimize(new: _Score, best: _Score) -> bool:
|
def minimize(new: _Score, best: _Score) -> bool:
|
||||||
"""New score should be smaller than the old one."""
|
"""New score should be lesser than the old one."""
|
||||||
return numpy.greater(get_s(best) - self._min_delta, get_s(new))
|
return numpy.greater(get_s(best) - self._min_delta, get_s(new))
|
||||||
|
|
||||||
if self.maximize is None:
|
if self.maximize is None:
|
||||||
@ -419,38 +422,53 @@ class EarlyStopping(TrainingCallback):
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
epoch += self.starting_round # training continuation
|
epoch += self.starting_round # training continuation
|
||||||
msg = "Must have at least 1 validation dataset for early stopping."
|
msg = "Must have at least 1 validation dataset for early stopping."
|
||||||
assert len(evals_log.keys()) >= 1, msg
|
if len(evals_log.keys()) < 1:
|
||||||
data_name = ""
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Get data name
|
||||||
if self.data:
|
if self.data:
|
||||||
for d, _ in evals_log.items():
|
data_name = self.data
|
||||||
if d == self.data:
|
|
||||||
data_name = d
|
|
||||||
if not data_name:
|
|
||||||
raise ValueError("No dataset named:", self.data)
|
|
||||||
else:
|
else:
|
||||||
# Use the last one as default.
|
# Use the last one as default.
|
||||||
data_name = list(evals_log.keys())[-1]
|
data_name = list(evals_log.keys())[-1]
|
||||||
assert isinstance(data_name, str) and data_name
|
if data_name not in evals_log:
|
||||||
|
raise ValueError(f"No dataset named: {data_name}")
|
||||||
|
|
||||||
|
if not isinstance(data_name, str):
|
||||||
|
raise TypeError(
|
||||||
|
f"The name of the dataset should be a string. Got: {type(data_name)}"
|
||||||
|
)
|
||||||
data_log = evals_log[data_name]
|
data_log = evals_log[data_name]
|
||||||
|
|
||||||
# Filter out scores that can not be used for early stopping.
|
# Get metric name
|
||||||
if self.metric_name:
|
if self.metric_name:
|
||||||
metric_name = self.metric_name
|
metric_name = self.metric_name
|
||||||
else:
|
else:
|
||||||
# Use last metric by default.
|
# Use last metric by default.
|
||||||
assert isinstance(data_log, collections.OrderedDict)
|
|
||||||
metric_name = list(data_log.keys())[-1]
|
metric_name = list(data_log.keys())[-1]
|
||||||
|
if metric_name not in data_log:
|
||||||
|
raise ValueError(f"No metric named: {metric_name}")
|
||||||
|
|
||||||
|
# The latest score
|
||||||
score = data_log[metric_name][-1]
|
score = data_log[metric_name][-1]
|
||||||
return self._update_rounds(score, data_name, metric_name, model, epoch)
|
return self._update_rounds(score, data_name, metric_name, model, epoch)
|
||||||
|
|
||||||
def after_training(self, model: _Model) -> _Model:
|
def after_training(self, model: _Model) -> _Model:
|
||||||
|
if not self.save_best:
|
||||||
|
return model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.save_best:
|
best_iteration = model.best_iteration
|
||||||
model = model[: int(model.attr("best_iteration")) + 1]
|
best_score = model.best_score
|
||||||
|
assert best_iteration is not None and best_score is not None
|
||||||
|
model = model[: best_iteration + 1]
|
||||||
|
model.best_iteration = best_iteration
|
||||||
|
model.best_score = best_score
|
||||||
except XGBoostError as e:
|
except XGBoostError as e:
|
||||||
raise XGBoostError(
|
raise XGBoostError(
|
||||||
"`save_best` is not applicable to current booster"
|
"`save_best` is not applicable to the current booster"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -462,8 +480,6 @@ class EvaluationMonitor(TrainingCallback):
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|
||||||
metric :
|
|
||||||
Extra user defined metric.
|
|
||||||
rank :
|
rank :
|
||||||
Which worker should be used for printing the result.
|
Which worker should be used for printing the result.
|
||||||
period :
|
period :
|
||||||
|
|||||||
@ -1890,7 +1890,7 @@ class Booster:
|
|||||||
attr_names = from_cstr_to_pystr(sarr, length)
|
attr_names = from_cstr_to_pystr(sarr, length)
|
||||||
return {n: self.attr(n) for n in attr_names}
|
return {n: self.attr(n) for n in attr_names}
|
||||||
|
|
||||||
def set_attr(self, **kwargs: Optional[str]) -> None:
|
def set_attr(self, **kwargs: Optional[Any]) -> None:
|
||||||
"""Set the attribute of the Booster.
|
"""Set the attribute of the Booster.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -2559,10 +2559,35 @@ class Booster:
|
|||||||
else:
|
else:
|
||||||
raise TypeError("Unknown file type: ", fname)
|
raise TypeError("Unknown file type: ", fname)
|
||||||
|
|
||||||
if self.attr("best_iteration") is not None:
|
@property
|
||||||
self.best_iteration = int(cast(int, self.attr("best_iteration")))
|
def best_iteration(self) -> int:
|
||||||
if self.attr("best_score") is not None:
|
"""The best iteration during training."""
|
||||||
self.best_score = float(cast(float, self.attr("best_score")))
|
best = self.attr("best_iteration")
|
||||||
|
if best is not None:
|
||||||
|
return int(best)
|
||||||
|
|
||||||
|
raise AttributeError(
|
||||||
|
"`best_iteration` is only defined when early stopping is used."
|
||||||
|
)
|
||||||
|
|
||||||
|
@best_iteration.setter
|
||||||
|
def best_iteration(self, iteration: int) -> None:
|
||||||
|
self.set_attr(best_iteration=iteration)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def best_score(self) -> float:
|
||||||
|
"""The best evaluation score during training."""
|
||||||
|
best = self.attr("best_score")
|
||||||
|
if best is not None:
|
||||||
|
return float(best)
|
||||||
|
|
||||||
|
raise AttributeError(
|
||||||
|
"`best_score` is only defined when early stopping is used."
|
||||||
|
)
|
||||||
|
|
||||||
|
@best_score.setter
|
||||||
|
def best_score(self, score: int) -> None:
|
||||||
|
self.set_attr(best_score=score)
|
||||||
|
|
||||||
def num_boosted_rounds(self) -> int:
|
def num_boosted_rounds(self) -> int:
|
||||||
"""Get number of boosted rounds. For gblinear this is reset to 0 after
|
"""Get number of boosted rounds. For gblinear this is reset to 0 after
|
||||||
|
|||||||
@ -230,10 +230,10 @@ __model_doc = f"""
|
|||||||
subsample : Optional[float]
|
subsample : Optional[float]
|
||||||
Subsample ratio of the training instance.
|
Subsample ratio of the training instance.
|
||||||
sampling_method :
|
sampling_method :
|
||||||
Sampling method. Used only by `gpu_hist` tree method.
|
Sampling method. Used only by the GPU version of ``hist`` tree method.
|
||||||
- `uniform`: select random training instances uniformly.
|
- ``uniform``: select random training instances uniformly.
|
||||||
- `gradient_based` select random training instances with higher probability when
|
- ``gradient_based`` select random training instances with higher probability
|
||||||
the gradient and hessian are larger. (cf. CatBoost)
|
when the gradient and hessian are larger. (cf. CatBoost)
|
||||||
colsample_bytree : Optional[float]
|
colsample_bytree : Optional[float]
|
||||||
Subsample ratio of columns when constructing each tree.
|
Subsample ratio of columns when constructing each tree.
|
||||||
colsample_bylevel : Optional[float]
|
colsample_bylevel : Optional[float]
|
||||||
@ -992,12 +992,12 @@ class XGBModel(XGBModelBase):
|
|||||||
X :
|
X :
|
||||||
Feature matrix. See :ref:`py-data` for a list of supported types.
|
Feature matrix. See :ref:`py-data` for a list of supported types.
|
||||||
|
|
||||||
When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
|
When the ``tree_method`` is set to ``hist``, internally, the
|
||||||
:py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
|
:py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
|
||||||
for conserving memory. However, this has performance implications when the
|
for conserving memory. However, this has performance implications when the
|
||||||
device of input data is not matched with algorithm. For instance, if the
|
device of input data is not matched with algorithm. For instance, if the
|
||||||
input is a numpy array on CPU but ``gpu_hist`` is used for training, then
|
input is a numpy array on CPU but ``cuda`` is used for training, then the
|
||||||
the data is first processed on CPU then transferred to GPU.
|
data is first processed on CPU then transferred to GPU.
|
||||||
y :
|
y :
|
||||||
Labels
|
Labels
|
||||||
sample_weight :
|
sample_weight :
|
||||||
@ -1279,19 +1279,10 @@ class XGBModel(XGBModelBase):
|
|||||||
)
|
)
|
||||||
return np.array(feature_names)
|
return np.array(feature_names)
|
||||||
|
|
||||||
def _early_stopping_attr(self, attr: str) -> Union[float, int]:
|
|
||||||
booster = self.get_booster()
|
|
||||||
try:
|
|
||||||
return getattr(booster, attr)
|
|
||||||
except AttributeError as e:
|
|
||||||
raise AttributeError(
|
|
||||||
f"`{attr}` in only defined when early stopping is used."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def best_score(self) -> float:
|
def best_score(self) -> float:
|
||||||
"""The best score obtained by early stopping."""
|
"""The best score obtained by early stopping."""
|
||||||
return float(self._early_stopping_attr("best_score"))
|
return self.get_booster().best_score
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def best_iteration(self) -> int:
|
def best_iteration(self) -> int:
|
||||||
@ -1299,7 +1290,7 @@ class XGBModel(XGBModelBase):
|
|||||||
for instance if the best iteration is the first round, then best_iteration is 0.
|
for instance if the best iteration is the first round, then best_iteration is 0.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return int(self._early_stopping_attr("best_iteration"))
|
return self.get_booster().best_iteration
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def feature_importances_(self) -> np.ndarray:
|
def feature_importances_(self) -> np.ndarray:
|
||||||
@ -1926,12 +1917,12 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
| 1 | :math:`x_{20}` | :math:`x_{21}` |
|
| 1 | :math:`x_{20}` | :math:`x_{21}` |
|
||||||
+-----+----------------+----------------+
|
+-----+----------------+----------------+
|
||||||
|
|
||||||
When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
|
When the ``tree_method`` is set to ``hist``, internally, the
|
||||||
:py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
|
:py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
|
||||||
for conserving memory. However, this has performance implications when the
|
for conserving memory. However, this has performance implications when the
|
||||||
device of input data is not matched with algorithm. For instance, if the
|
device of input data is not matched with algorithm. For instance, if the
|
||||||
input is a numpy array on CPU but ``gpu_hist`` is used for training, then
|
input is a numpy array on CPU but ``cuda`` is used for training, then the
|
||||||
the data is first processed on CPU then transferred to GPU.
|
data is first processed on CPU then transferred to GPU.
|
||||||
y :
|
y :
|
||||||
Labels
|
Labels
|
||||||
group :
|
group :
|
||||||
|
|||||||
@ -28,17 +28,6 @@ from .core import (
|
|||||||
_CVFolds = Sequence["CVPack"]
|
_CVFolds = Sequence["CVPack"]
|
||||||
|
|
||||||
|
|
||||||
def _assert_new_callback(callbacks: Optional[Sequence[TrainingCallback]]) -> None:
|
|
||||||
is_new_callback: bool = not callbacks or all(
|
|
||||||
isinstance(c, TrainingCallback) for c in callbacks
|
|
||||||
)
|
|
||||||
if not is_new_callback:
|
|
||||||
link = "https://xgboost.readthedocs.io/en/latest/python/callbacks.html"
|
|
||||||
raise ValueError(
|
|
||||||
f"Old style callback was removed in version 1.6. See: {link}."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _configure_custom_metric(
|
def _configure_custom_metric(
|
||||||
feval: Optional[Metric], custom_metric: Optional[Metric]
|
feval: Optional[Metric], custom_metric: Optional[Metric]
|
||||||
) -> Optional[Metric]:
|
) -> Optional[Metric]:
|
||||||
@ -170,7 +159,6 @@ def train(
|
|||||||
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
|
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
|
||||||
start_iteration = 0
|
start_iteration = 0
|
||||||
|
|
||||||
_assert_new_callback(callbacks)
|
|
||||||
if verbose_eval:
|
if verbose_eval:
|
||||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||||
callbacks.append(EvaluationMonitor(period=verbose_eval))
|
callbacks.append(EvaluationMonitor(period=verbose_eval))
|
||||||
@ -247,7 +235,7 @@ class _PackedBooster:
|
|||||||
result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
|
result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def set_attr(self, **kwargs: Optional[str]) -> Any:
|
def set_attr(self, **kwargs: Optional[Any]) -> Any:
|
||||||
"""Iterate through folds for setting attributes"""
|
"""Iterate through folds for setting attributes"""
|
||||||
for f in self.cvfolds:
|
for f in self.cvfolds:
|
||||||
f.bst.set_attr(**kwargs)
|
f.bst.set_attr(**kwargs)
|
||||||
@ -274,11 +262,20 @@ class _PackedBooster:
|
|||||||
"""Get best_iteration"""
|
"""Get best_iteration"""
|
||||||
return int(cast(int, self.cvfolds[0].bst.attr("best_iteration")))
|
return int(cast(int, self.cvfolds[0].bst.attr("best_iteration")))
|
||||||
|
|
||||||
|
@best_iteration.setter
|
||||||
|
def best_iteration(self, iteration: int) -> None:
|
||||||
|
"""Get best_iteration"""
|
||||||
|
self.set_attr(best_iteration=iteration)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def best_score(self) -> float:
|
def best_score(self) -> float:
|
||||||
"""Get best_score."""
|
"""Get best_score."""
|
||||||
return float(cast(float, self.cvfolds[0].bst.attr("best_score")))
|
return float(cast(float, self.cvfolds[0].bst.attr("best_score")))
|
||||||
|
|
||||||
|
@best_score.setter
|
||||||
|
def best_score(self, score: float) -> None:
|
||||||
|
self.set_attr(best_score=score)
|
||||||
|
|
||||||
|
|
||||||
def groups_to_rows(groups: List[np.ndarray], boundaries: np.ndarray) -> np.ndarray:
|
def groups_to_rows(groups: List[np.ndarray], boundaries: np.ndarray) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
@ -551,7 +548,6 @@ def cv(
|
|||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
callbacks = [] if callbacks is None else copy.copy(list(callbacks))
|
callbacks = [] if callbacks is None else copy.copy(list(callbacks))
|
||||||
_assert_new_callback(callbacks)
|
|
||||||
|
|
||||||
if verbose_eval:
|
if verbose_eval:
|
||||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||||
|
|||||||
@ -37,6 +37,7 @@ class LintersPaths:
|
|||||||
"demo/rmm_plugin",
|
"demo/rmm_plugin",
|
||||||
"demo/json-model/json_parser.py",
|
"demo/json-model/json_parser.py",
|
||||||
"demo/guide-python/cat_in_the_dat.py",
|
"demo/guide-python/cat_in_the_dat.py",
|
||||||
|
"demo/guide-python/callbacks.py",
|
||||||
"demo/guide-python/categorical.py",
|
"demo/guide-python/categorical.py",
|
||||||
"demo/guide-python/feature_weights.py",
|
"demo/guide-python/feature_weights.py",
|
||||||
"demo/guide-python/sklearn_parallel.py",
|
"demo/guide-python/sklearn_parallel.py",
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from contextlib import nullcontext
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -104,15 +103,6 @@ class TestCallbacks:
|
|||||||
dump = booster.get_dump(dump_format='json')
|
dump = booster.get_dump(dump_format='json')
|
||||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||||
|
|
||||||
# No early stopping, best_iteration should be set to last epoch
|
|
||||||
booster = xgb.train({'objective': 'binary:logistic',
|
|
||||||
'eval_metric': 'error'}, D_train,
|
|
||||||
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
|
||||||
num_boost_round=10,
|
|
||||||
evals_result=evals_result,
|
|
||||||
verbose_eval=True)
|
|
||||||
assert booster.num_boosted_rounds() - 1 == booster.best_iteration
|
|
||||||
|
|
||||||
def test_early_stopping_custom_eval(self):
|
def test_early_stopping_custom_eval(self):
|
||||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||||
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
|
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
|
||||||
@ -204,8 +194,9 @@ class TestCallbacks:
|
|||||||
X, y = load_breast_cancer(return_X_y=True)
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
n_estimators = 100
|
n_estimators = 100
|
||||||
early_stopping_rounds = 5
|
early_stopping_rounds = 5
|
||||||
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
early_stop = xgb.callback.EarlyStopping(
|
||||||
save_best=True)
|
rounds=early_stopping_rounds, save_best=True
|
||||||
|
)
|
||||||
cls = xgb.XGBClassifier(
|
cls = xgb.XGBClassifier(
|
||||||
n_estimators=n_estimators,
|
n_estimators=n_estimators,
|
||||||
eval_metric=tm.eval_error_metric_skl,
|
eval_metric=tm.eval_error_metric_skl,
|
||||||
@ -216,20 +207,27 @@ class TestCallbacks:
|
|||||||
dump = booster.get_dump(dump_format='json')
|
dump = booster.get_dump(dump_format='json')
|
||||||
assert len(dump) == booster.best_iteration + 1
|
assert len(dump) == booster.best_iteration + 1
|
||||||
|
|
||||||
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
early_stop = xgb.callback.EarlyStopping(
|
||||||
save_best=True)
|
rounds=early_stopping_rounds, save_best=True
|
||||||
|
)
|
||||||
cls = xgb.XGBClassifier(
|
cls = xgb.XGBClassifier(
|
||||||
booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl
|
booster="gblinear",
|
||||||
|
n_estimators=10,
|
||||||
|
eval_metric=tm.eval_error_metric_skl,
|
||||||
|
callbacks=[early_stop],
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
|
cls.fit(X, y, eval_set=[(X, y)])
|
||||||
|
|
||||||
# No error
|
# No error
|
||||||
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||||
save_best=False)
|
save_best=False)
|
||||||
xgb.XGBClassifier(
|
xgb.XGBClassifier(
|
||||||
booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl
|
booster="gblinear",
|
||||||
).fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
|
n_estimators=10,
|
||||||
|
eval_metric=tm.eval_error_metric_skl,
|
||||||
|
callbacks=[early_stop],
|
||||||
|
).fit(X, y, eval_set=[(X, y)])
|
||||||
|
|
||||||
def test_early_stopping_continuation(self):
|
def test_early_stopping_continuation(self):
|
||||||
from sklearn.datasets import load_breast_cancer
|
from sklearn.datasets import load_breast_cancer
|
||||||
@ -252,8 +250,11 @@ class TestCallbacks:
|
|||||||
cls.load_model(path)
|
cls.load_model(path)
|
||||||
assert cls._Booster is not None
|
assert cls._Booster is not None
|
||||||
early_stopping_rounds = 3
|
early_stopping_rounds = 3
|
||||||
cls.set_params(eval_metric=tm.eval_error_metric_skl)
|
cls.set_params(
|
||||||
cls.fit(X, y, eval_set=[(X, y)], early_stopping_rounds=early_stopping_rounds)
|
eval_metric=tm.eval_error_metric_skl,
|
||||||
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
|
)
|
||||||
|
cls.fit(X, y, eval_set=[(X, y)])
|
||||||
booster = cls.get_booster()
|
booster = cls.get_booster()
|
||||||
assert booster.num_boosted_rounds() == \
|
assert booster.num_boosted_rounds() == \
|
||||||
booster.best_iteration + early_stopping_rounds + 1
|
booster.best_iteration + early_stopping_rounds + 1
|
||||||
@ -280,20 +281,20 @@ class TestCallbacks:
|
|||||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||||
num_round = 4
|
num_round = 4
|
||||||
|
|
||||||
warning_check = nullcontext()
|
|
||||||
|
|
||||||
# learning_rates as a list
|
# learning_rates as a list
|
||||||
# init eta with 0 to check whether learning_rates work
|
# init eta with 0 to check whether learning_rates work
|
||||||
param = {'max_depth': 2, 'eta': 0, 'verbosity': 0,
|
param = {'max_depth': 2, 'eta': 0, 'verbosity': 0,
|
||||||
'objective': 'binary:logistic', 'eval_metric': 'error',
|
'objective': 'binary:logistic', 'eval_metric': 'error',
|
||||||
'tree_method': tree_method}
|
'tree_method': tree_method}
|
||||||
evals_result = {}
|
evals_result = {}
|
||||||
with warning_check:
|
bst = xgb.train(
|
||||||
bst = xgb.train(param, dtrain, num_round, watchlist,
|
param,
|
||||||
callbacks=[scheduler([
|
dtrain,
|
||||||
0.8, 0.7, 0.6, 0.5
|
num_round,
|
||||||
])],
|
evals=watchlist,
|
||||||
evals_result=evals_result)
|
callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
|
||||||
|
evals_result=evals_result,
|
||||||
|
)
|
||||||
eval_errors_0 = list(map(float, evals_result['eval']['error']))
|
eval_errors_0 = list(map(float, evals_result['eval']['error']))
|
||||||
assert isinstance(bst, xgb.core.Booster)
|
assert isinstance(bst, xgb.core.Booster)
|
||||||
# validation error should decrease, if eta > 0
|
# validation error should decrease, if eta > 0
|
||||||
@ -304,11 +305,15 @@ class TestCallbacks:
|
|||||||
'objective': 'binary:logistic', 'eval_metric': 'error',
|
'objective': 'binary:logistic', 'eval_metric': 'error',
|
||||||
'tree_method': tree_method}
|
'tree_method': tree_method}
|
||||||
evals_result = {}
|
evals_result = {}
|
||||||
with warning_check:
|
|
||||||
bst = xgb.train(param, dtrain, num_round, watchlist,
|
bst = xgb.train(
|
||||||
callbacks=[scheduler(
|
param,
|
||||||
[0.8, 0.7, 0.6, 0.5])],
|
dtrain,
|
||||||
evals_result=evals_result)
|
num_round,
|
||||||
|
evals=watchlist,
|
||||||
|
callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
|
||||||
|
evals_result=evals_result,
|
||||||
|
)
|
||||||
eval_errors_1 = list(map(float, evals_result['eval']['error']))
|
eval_errors_1 = list(map(float, evals_result['eval']['error']))
|
||||||
assert isinstance(bst, xgb.core.Booster)
|
assert isinstance(bst, xgb.core.Booster)
|
||||||
# validation error should decrease, if learning_rate > 0
|
# validation error should decrease, if learning_rate > 0
|
||||||
@ -320,12 +325,14 @@ class TestCallbacks:
|
|||||||
'eval_metric': 'error', 'tree_method': tree_method
|
'eval_metric': 'error', 'tree_method': tree_method
|
||||||
}
|
}
|
||||||
evals_result = {}
|
evals_result = {}
|
||||||
with warning_check:
|
bst = xgb.train(
|
||||||
bst = xgb.train(param, dtrain, num_round, watchlist,
|
param,
|
||||||
callbacks=[scheduler(
|
dtrain,
|
||||||
[0, 0, 0, 0]
|
num_round,
|
||||||
)],
|
evals=watchlist,
|
||||||
evals_result=evals_result)
|
callbacks=[scheduler([0, 0, 0, 0])],
|
||||||
|
evals_result=evals_result,
|
||||||
|
)
|
||||||
eval_errors_2 = list(map(float, evals_result['eval']['error']))
|
eval_errors_2 = list(map(float, evals_result['eval']['error']))
|
||||||
assert isinstance(bst, xgb.core.Booster)
|
assert isinstance(bst, xgb.core.Booster)
|
||||||
# validation error should not decrease, if eta/learning_rate = 0
|
# validation error should not decrease, if eta/learning_rate = 0
|
||||||
@ -336,12 +343,14 @@ class TestCallbacks:
|
|||||||
return num_boost_round / (ithround + 1)
|
return num_boost_round / (ithround + 1)
|
||||||
|
|
||||||
evals_result = {}
|
evals_result = {}
|
||||||
with warning_check:
|
bst = xgb.train(
|
||||||
bst = xgb.train(param, dtrain, num_round, watchlist,
|
param,
|
||||||
callbacks=[
|
dtrain,
|
||||||
scheduler(eta_decay)
|
num_round,
|
||||||
],
|
evals=watchlist,
|
||||||
evals_result=evals_result)
|
callbacks=[scheduler(eta_decay)],
|
||||||
|
evals_result=evals_result,
|
||||||
|
)
|
||||||
eval_errors_3 = list(map(float, evals_result['eval']['error']))
|
eval_errors_3 = list(map(float, evals_result['eval']['error']))
|
||||||
|
|
||||||
assert isinstance(bst, xgb.core.Booster)
|
assert isinstance(bst, xgb.core.Booster)
|
||||||
@ -351,8 +360,7 @@ class TestCallbacks:
|
|||||||
for i in range(1, len(eval_errors_0)):
|
for i in range(1, len(eval_errors_0)):
|
||||||
assert eval_errors_3[i] != eval_errors_2[i]
|
assert eval_errors_3[i] != eval_errors_2[i]
|
||||||
|
|
||||||
with warning_check:
|
xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
|
||||||
xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
|
|
||||||
|
|
||||||
def run_eta_decay_leaf_output(self, tree_method: str, objective: str) -> None:
|
def run_eta_decay_leaf_output(self, tree_method: str, objective: str) -> None:
|
||||||
# check decay has effect on leaf output.
|
# check decay has effect on leaf output.
|
||||||
@ -378,7 +386,7 @@ class TestCallbacks:
|
|||||||
param,
|
param,
|
||||||
dtrain,
|
dtrain,
|
||||||
num_round,
|
num_round,
|
||||||
watchlist,
|
evals=watchlist,
|
||||||
callbacks=[scheduler(eta_decay_0)],
|
callbacks=[scheduler(eta_decay_0)],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -391,7 +399,7 @@ class TestCallbacks:
|
|||||||
param,
|
param,
|
||||||
dtrain,
|
dtrain,
|
||||||
num_round,
|
num_round,
|
||||||
watchlist,
|
evals=watchlist,
|
||||||
callbacks=[scheduler(eta_decay_1)],
|
callbacks=[scheduler(eta_decay_1)],
|
||||||
)
|
)
|
||||||
bst_json0 = bst0.save_raw(raw_format="json")
|
bst_json0 = bst0.save_raw(raw_format="json")
|
||||||
@ -474,3 +482,24 @@ class TestCallbacks:
|
|||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
assert len(callbacks) == 1
|
assert len(callbacks) == 1
|
||||||
|
|
||||||
|
def test_attribute_error(self) -> None:
|
||||||
|
from sklearn.datasets import load_breast_cancer
|
||||||
|
|
||||||
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
|
|
||||||
|
clf = xgb.XGBClassifier(n_estimators=8)
|
||||||
|
clf.fit(X, y, eval_set=[(X, y)])
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError, match="early stopping is used"):
|
||||||
|
clf.best_iteration
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError, match="early stopping is used"):
|
||||||
|
clf.best_score
|
||||||
|
|
||||||
|
booster = clf.get_booster()
|
||||||
|
with pytest.raises(AttributeError, match="early stopping is used"):
|
||||||
|
booster.best_iteration
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError, match="early stopping is used"):
|
||||||
|
booster.best_score
|
||||||
|
|||||||
@ -173,7 +173,7 @@ class TestInplacePredict:
|
|||||||
np.testing.assert_allclose(predt_from_dmatrix, predt_from_array)
|
np.testing.assert_allclose(predt_from_dmatrix, predt_from_array)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
booster.predict(test, iteration_range=(0, booster.best_iteration + 2))
|
booster.predict(test, iteration_range=(0, booster.num_boosted_rounds() + 2))
|
||||||
|
|
||||||
default = booster.predict(test)
|
default = booster.predict(test)
|
||||||
|
|
||||||
@ -181,7 +181,7 @@ class TestInplacePredict:
|
|||||||
np.testing.assert_allclose(range_full, default)
|
np.testing.assert_allclose(range_full, default)
|
||||||
|
|
||||||
range_full = booster.predict(
|
range_full = booster.predict(
|
||||||
test, iteration_range=(0, booster.best_iteration + 1)
|
test, iteration_range=(0, booster.num_boosted_rounds())
|
||||||
)
|
)
|
||||||
np.testing.assert_allclose(range_full, default)
|
np.testing.assert_allclose(range_full, default)
|
||||||
|
|
||||||
|
|||||||
@ -100,8 +100,8 @@ class TestTrainingContinuation:
|
|||||||
res2 = mean_squared_error(
|
res2 = mean_squared_error(
|
||||||
y_2class,
|
y_2class,
|
||||||
gbdt_04.predict(
|
gbdt_04.predict(
|
||||||
dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
|
dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
assert res1 == res2
|
assert res1 == res2
|
||||||
|
|
||||||
@ -112,7 +112,7 @@ class TestTrainingContinuation:
|
|||||||
res2 = mean_squared_error(
|
res2 = mean_squared_error(
|
||||||
y_2class,
|
y_2class,
|
||||||
gbdt_04.predict(
|
gbdt_04.predict(
|
||||||
dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
|
dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
assert res1 == res2
|
assert res1 == res2
|
||||||
@ -126,7 +126,7 @@ class TestTrainingContinuation:
|
|||||||
|
|
||||||
res1 = gbdt_05.predict(dtrain_5class)
|
res1 = gbdt_05.predict(dtrain_5class)
|
||||||
res2 = gbdt_05.predict(
|
res2 = gbdt_05.predict(
|
||||||
dtrain_5class, iteration_range=(0, gbdt_05.best_iteration + 1)
|
dtrain_5class, iteration_range=(0, gbdt_05.num_boosted_rounds())
|
||||||
)
|
)
|
||||||
np.testing.assert_almost_equal(res1, res2)
|
np.testing.assert_almost_equal(res1, res2)
|
||||||
|
|
||||||
@ -138,15 +138,16 @@ class TestTrainingContinuation:
|
|||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_training_continuation_updaters_json(self):
|
def test_training_continuation_updaters_json(self):
|
||||||
# Picked up from R tests.
|
# Picked up from R tests.
|
||||||
updaters = 'grow_colmaker,prune,refresh'
|
updaters = "grow_colmaker,prune,refresh"
|
||||||
params = self.generate_parameters()
|
params = self.generate_parameters()
|
||||||
for p in params:
|
for p in params:
|
||||||
p['updater'] = updaters
|
p["updater"] = updaters
|
||||||
self.run_training_continuation(params[0], params[1], params[2])
|
self.run_training_continuation(params[0], params[1], params[2])
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_changed_parameter(self):
|
def test_changed_parameter(self):
|
||||||
from sklearn.datasets import load_breast_cancer
|
from sklearn.datasets import load_breast_cancer
|
||||||
|
|
||||||
X, y = load_breast_cancer(return_X_y=True)
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
clf = xgb.XGBClassifier(n_estimators=2)
|
clf = xgb.XGBClassifier(n_estimators=2)
|
||||||
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
|
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user