[breaking] Remove deprecated parameters in the skl interface. (#9986)

This commit is contained in:
Jiaming Yuan 2024-01-15 20:40:05 +08:00 committed by GitHub
parent 2de85d3241
commit 0798e36d73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 418 additions and 462 deletions

View File

@ -16,14 +16,14 @@ def training_continuation(tmpdir: str, use_pickle: bool) -> None:
"""Basic training continuation."""
# Train 128 iterations in 1 session
X, y = load_breast_cancer(return_X_y=True)
clf = xgboost.XGBClassifier(n_estimators=128)
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
clf = xgboost.XGBClassifier(n_estimators=128, eval_metric="logloss")
clf.fit(X, y, eval_set=[(X, y)])
print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
# Train 128 iterations in 2 sessions, with the first one runs for 32 iterations and
# the second one runs for 96 iterations
clf = xgboost.XGBClassifier(n_estimators=32)
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
clf = xgboost.XGBClassifier(n_estimators=32, eval_metric="logloss")
clf.fit(X, y, eval_set=[(X, y)])
assert clf.get_booster().num_boosted_rounds() == 32
# load back the model, this could be a checkpoint
@ -39,8 +39,8 @@ def training_continuation(tmpdir: str, use_pickle: bool) -> None:
loaded = xgboost.XGBClassifier()
loaded.load_model(path)
clf = xgboost.XGBClassifier(n_estimators=128 - 32)
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", xgb_model=loaded)
clf = xgboost.XGBClassifier(n_estimators=128 - 32, eval_metric="logloss")
clf.fit(X, y, eval_set=[(X, y)], xgb_model=loaded)
print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
@ -56,19 +56,24 @@ def training_continuation_early_stop(tmpdir: str, use_pickle: bool) -> None:
n_estimators = 512
X, y = load_breast_cancer(return_X_y=True)
clf = xgboost.XGBClassifier(n_estimators=n_estimators)
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", callbacks=[early_stop])
clf = xgboost.XGBClassifier(
n_estimators=n_estimators, eval_metric="logloss", callbacks=[early_stop]
)
clf.fit(X, y, eval_set=[(X, y)])
print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
best = clf.best_iteration
# Train 512 iterations in 2 sessions, with the first one runs for 128 iterations and
# the second one runs until early stop.
clf = xgboost.XGBClassifier(n_estimators=128)
clf = xgboost.XGBClassifier(
n_estimators=128, eval_metric="logloss", callbacks=[early_stop]
)
# Reinitialize the early stop callback
early_stop = xgboost.callback.EarlyStopping(
rounds=early_stopping_rounds, save_best=True
)
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", callbacks=[early_stop])
clf.set_params(callbacks=[early_stop])
clf.fit(X, y, eval_set=[(X, y)])
assert clf.get_booster().num_boosted_rounds() == 128
# load back the model, this could be a checkpoint
@ -87,13 +92,13 @@ def training_continuation_early_stop(tmpdir: str, use_pickle: bool) -> None:
early_stop = xgboost.callback.EarlyStopping(
rounds=early_stopping_rounds, save_best=True
)
clf = xgboost.XGBClassifier(n_estimators=n_estimators - 128)
clf = xgboost.XGBClassifier(
n_estimators=n_estimators - 128, eval_metric="logloss", callbacks=[early_stop]
)
clf.fit(
X,
y,
eval_set=[(X, y)],
eval_metric="logloss",
callbacks=[early_stop],
xgb_model=loaded,
)

View File

@ -16,30 +16,35 @@ labels, y = np.unique(y, return_inverse=True)
X_train, X_test = X[:1600], X[1600:]
y_train, y_test = y[:1600], y[1600:]
param_dist = {'objective':'binary:logistic', 'n_estimators':2}
param_dist = {"objective": "binary:logistic", "n_estimators": 2}
clf = xgb.XGBModel(**param_dist)
clf = xgb.XGBModel(
**param_dist,
eval_metric="logloss",
)
# Or you can use: clf = xgb.XGBClassifier(**param_dist)
clf.fit(X_train, y_train,
clf.fit(
X_train,
y_train,
eval_set=[(X_train, y_train), (X_test, y_test)],
eval_metric='logloss',
verbose=True)
verbose=True,
)
# Load evals result by calling the evals_result() function
evals_result = clf.evals_result()
print('Access logloss metric directly from validation_0:')
print(evals_result['validation_0']['logloss'])
print("Access logloss metric directly from validation_0:")
print(evals_result["validation_0"]["logloss"])
print('')
print('Access metrics through a loop:')
print("")
print("Access metrics through a loop:")
for e_name, e_mtrs in evals_result.items():
print('- {}'.format(e_name))
print("- {}".format(e_name))
for e_mtr_name, e_mtr_vals in e_mtrs.items():
print(' - {}'.format(e_mtr_name))
print(' - {}'.format(e_mtr_vals))
print(" - {}".format(e_mtr_name))
print(" - {}".format(e_mtr_vals))
print('')
print('Access complete dict:')
print("")
print("Access complete dict:")
print(evals_result)

View File

@ -1,4 +1,4 @@
'''
"""
Collection of examples for using sklearn interface
==================================================
@ -8,7 +8,7 @@ For an introduction to XGBoost's scikit-learn estimator interface, see
Created on 1 Apr 2015
@author: Jamie Hall
'''
"""
import pickle
import numpy as np
@ -22,8 +22,8 @@ rng = np.random.RandomState(31337)
print("Zeros and Ones from the Digits dataset: binary classification")
digits = load_digits(n_class=2)
y = digits['target']
X = digits['data']
y = digits["target"]
X = digits["data"]
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X):
xgb_model = xgb.XGBClassifier(n_jobs=1).fit(X[train_index], y[train_index])
@ -33,8 +33,8 @@ for train_index, test_index in kf.split(X):
print("Iris: multiclass classification")
iris = load_iris()
y = iris['target']
X = iris['data']
y = iris["target"]
X = iris["data"]
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X):
xgb_model = xgb.XGBClassifier(n_jobs=1).fit(X[train_index], y[train_index])
@ -53,9 +53,13 @@ for train_index, test_index in kf.split(X):
print("Parameter optimization")
xgb_model = xgb.XGBRegressor(n_jobs=1)
clf = GridSearchCV(xgb_model,
{'max_depth': [2, 4],
'n_estimators': [50, 100]}, verbose=1, n_jobs=1, cv=3)
clf = GridSearchCV(
xgb_model,
{"max_depth": [2, 4], "n_estimators": [50, 100]},
verbose=1,
n_jobs=1,
cv=3,
)
clf.fit(X, y)
print(clf.best_score_)
print(clf.best_params_)
@ -69,9 +73,8 @@ print(np.allclose(clf.predict(X), clf2.predict(X)))
# Early-stopping
X = digits['data']
y = digits['target']
X = digits["data"]
y = digits["target"]
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = xgb.XGBClassifier(n_jobs=1)
clf.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="auc",
eval_set=[(X_test, y_test)])
clf = xgb.XGBClassifier(n_jobs=1, early_stopping_rounds=10, eval_metric="auc")
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])

View File

@ -12,6 +12,7 @@ import xgboost as xgb
if __name__ == "__main__":
print("Parallel Parameter optimization")
X, y = fetch_california_housing(return_X_y=True)
# Make sure the number of threads is balanced.
xgb_model = xgb.XGBRegressor(
n_jobs=multiprocessing.cpu_count() // 2, tree_method="hist"
)

View File

@ -123,11 +123,11 @@ monitor our model's performance. As mentioned above, the default metric for ``S
elements = np.power(np.log1p(y) - np.log1p(predt), 2)
return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))
Since we are demonstrating in Python, the metric or objective need not be a function,
any callable object should suffice. Similar to the objective function, our metric also
accepts ``predt`` and ``dtrain`` as inputs, but returns the name of the metric itself and a
floating point value as the result. After passing it into XGBoost as argument of ``feval``
parameter:
Since we are demonstrating in Python, the metric or objective need not be a function, any
callable object should suffice. Similar to the objective function, our metric also
accepts ``predt`` and ``dtrain`` as inputs, but returns the name of the metric itself and
a floating point value as the result. After passing it into XGBoost as argument of
``custom_metric`` parameter:
.. code-block:: python
@ -136,7 +136,7 @@ parameter:
dtrain=dtrain,
num_boost_round=10,
obj=squared_log,
feval=rmsle,
custom_metric=rmsle,
evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
evals_result=results)

View File

@ -61,7 +61,7 @@ from typing import (
import numpy
from xgboost import collective, config
from xgboost._typing import _T, FeatureNames, FeatureTypes, ModelIn
from xgboost._typing import _T, FeatureNames, FeatureTypes
from xgboost.callback import TrainingCallback
from xgboost.compat import DataFrame, LazyLoader, concat, lazy_isinstance
from xgboost.core import (
@ -1774,14 +1774,11 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection],
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
eval_metric: Optional[Union[str, Sequence[str], Metric]],
sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
base_margin_eval_set: Optional[Sequence[_DaskCollection]],
early_stopping_rounds: Optional[int],
verbose: Union[int, bool],
xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection],
callbacks: Optional[Sequence[TrainingCallback]],
) -> _DaskCollection:
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
@ -1809,9 +1806,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
obj: Optional[Callable] = _objective_decorator(self.objective)
else:
obj = None
model, metric, params, early_stopping_rounds, callbacks = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
model, metric, params = self._configure_fit(xgb_model, params)
results = await self.client.sync(
_train_async,
asynchronous=True,
@ -1826,8 +1821,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
feval=None,
custom_metric=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks,
early_stopping_rounds=self.early_stopping_rounds,
callbacks=self.callbacks,
xgb_model=model,
)
self._Booster = results["booster"]
@ -1844,14 +1839,11 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Union[int, bool] = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "DaskXGBRegressor":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -1871,14 +1863,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection],
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
eval_metric: Optional[Union[str, Sequence[str], Metric]],
sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
base_margin_eval_set: Optional[Sequence[_DaskCollection]],
early_stopping_rounds: Optional[int],
verbose: Union[int, bool],
xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection],
callbacks: Optional[Sequence[TrainingCallback]],
) -> "DaskXGBClassifier":
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
@ -1924,9 +1913,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
obj: Optional[Callable] = _objective_decorator(self.objective)
else:
obj = None
model, metric, params, early_stopping_rounds, callbacks = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
model, metric, params = self._configure_fit(xgb_model, params)
results = await self.client.sync(
_train_async,
asynchronous=True,
@ -1941,8 +1928,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
feval=None,
custom_metric=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks,
early_stopping_rounds=self.early_stopping_rounds,
callbacks=self.callbacks,
xgb_model=model,
)
self._Booster = results["booster"]
@ -1960,14 +1947,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Union[int, bool] = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "DaskXGBClassifier":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -2063,7 +2047,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
if callable(objective):
raise ValueError("Custom objective function not supported by XGBRanker.")
super().__init__(objective=objective, kwargs=kwargs)
super().__init__(objective=objective, **kwargs)
async def _fit_async(
self,
@ -2078,12 +2062,9 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
base_margin_eval_set: Optional[Sequence[_DaskCollection]],
eval_group: Optional[Sequence[_DaskCollection]],
eval_qid: Optional[Sequence[_DaskCollection]],
eval_metric: Optional[Union[str, Sequence[str], Metric]],
early_stopping_rounds: Optional[int],
verbose: Union[int, bool],
xgb_model: Optional[Union[XGBModel, Booster]],
feature_weights: Optional[_DaskCollection],
callbacks: Optional[Sequence[TrainingCallback]],
) -> "DaskXGBRanker":
msg = "Use `qid` instead of `group` on dask interface."
if not (group is None and eval_group is None):
@ -2111,14 +2092,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
enable_categorical=self.enable_categorical,
feature_types=self.feature_types,
)
if eval_metric is not None:
if callable(eval_metric):
raise ValueError(
"Custom evaluation metric is not yet supported for XGBRanker."
)
model, metric, params, early_stopping_rounds, callbacks = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
model, metric, params = self._configure_fit(xgb_model, params)
results = await self.client.sync(
_train_async,
asynchronous=True,
@ -2133,8 +2107,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
feval=None,
custom_metric=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks,
early_stopping_rounds=self.early_stopping_rounds,
callbacks=self.callbacks,
xgb_model=model,
)
self._Booster = results["booster"]
@ -2155,14 +2129,11 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_group: Optional[Sequence[_DaskCollection]] = None,
eval_qid: Optional[Sequence[_DaskCollection]] = None,
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Union[int, bool] = False,
xgb_model: Optional[Union[XGBModel, Booster]] = None,
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "DaskXGBRanker":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -2221,18 +2192,15 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Union[int, bool] = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "DaskXGBRFRegressor":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks)
_check_rf_callback(self.early_stopping_rounds, self.callbacks)
super().fit(**args)
return self
@ -2285,17 +2253,14 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Union[int, bool] = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "DaskXGBRFClassifier":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks)
_check_rf_callback(self.early_stopping_rounds, self.callbacks)
super().fit(**args)
return self

View File

@ -349,12 +349,6 @@ __model_doc = f"""
See :doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more
information.
.. note::
This parameter replaces `eval_metric` in :py:meth:`fit` method. The old
one receives un-transformed prediction regardless of whether custom
objective is being used.
.. code-block:: python
from sklearn.datasets import load_diabetes
@ -389,10 +383,6 @@ __model_doc = f"""
early stopping. If there's more than one metric in **eval_metric**, the last
metric will be used for early stopping.
.. note::
This parameter replaces `early_stopping_rounds` in :py:meth:`fit` method.
callbacks : Optional[List[TrainingCallback]]
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using
@ -872,16 +862,11 @@ class XGBModel(XGBModelBase):
def _configure_fit(
self,
booster: Optional[Union[Booster, "XGBModel", str]],
eval_metric: Optional[Union[Callable, str, Sequence[str]]],
params: Dict[str, Any],
early_stopping_rounds: Optional[int],
callbacks: Optional[Sequence[TrainingCallback]],
) -> Tuple[
Optional[Union[Booster, str, "XGBModel"]],
Optional[Metric],
Dict[str, Any],
Optional[int],
Optional[Sequence[TrainingCallback]],
]:
"""Configure parameters for :py:meth:`fit`."""
if isinstance(booster, XGBModel):
@ -903,49 +888,16 @@ class XGBModel(XGBModelBase):
"or `set_params` instead."
)
# Configure evaluation metric.
if eval_metric is not None:
_deprecated("eval_metric")
if self.eval_metric is not None and eval_metric is not None:
_duplicated("eval_metric")
# - track where does the evaluation metric come from
if self.eval_metric is not None:
from_fit = False
eval_metric = self.eval_metric
else:
from_fit = True
# - configure callable evaluation metric
metric: Optional[Metric] = None
if eval_metric is not None:
if callable(eval_metric) and from_fit:
# No need to wrap the evaluation function for old parameter.
metric = eval_metric
elif callable(eval_metric):
# Parameter from constructor or set_params
if self.eval_metric is not None:
if callable(self.eval_metric):
if self._get_type() == "ranker":
metric = ltr_metric_decorator(eval_metric, self.n_jobs)
metric = ltr_metric_decorator(self.eval_metric, self.n_jobs)
else:
metric = _metric_decorator(eval_metric)
metric = _metric_decorator(self.eval_metric)
else:
params.update({"eval_metric": eval_metric})
# Configure early_stopping_rounds
if early_stopping_rounds is not None:
_deprecated("early_stopping_rounds")
if early_stopping_rounds is not None and self.early_stopping_rounds is not None:
_duplicated("early_stopping_rounds")
early_stopping_rounds = (
self.early_stopping_rounds
if self.early_stopping_rounds is not None
else early_stopping_rounds
)
# Configure callbacks
if callbacks is not None:
_deprecated("callbacks")
if callbacks is not None and self.callbacks is not None:
_duplicated("callbacks")
callbacks = self.callbacks if self.callbacks is not None else callbacks
params.update({"eval_metric": self.eval_metric})
tree_method = params.get("tree_method", None)
if self.enable_categorical and tree_method == "exact":
@ -953,7 +905,7 @@ class XGBModel(XGBModelBase):
"Experimental support for categorical data is not implemented for"
" current tree method yet."
)
return model, metric, params, early_stopping_rounds, callbacks
return model, metric, params
def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
# Use `QuantileDMatrix` to save memory.
@ -979,14 +931,11 @@ class XGBModel(XGBModelBase):
sample_weight: Optional[ArrayLike] = None,
base_margin: Optional[ArrayLike] = None,
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBModel":
# pylint: disable=invalid-name,attribute-defined-outside-init
"""Fit gradient boosting model.
@ -1017,18 +966,6 @@ class XGBModel(XGBModelBase):
metrics will be computed.
Validation metrics will help us track the performance of the model.
eval_metric : str, list of str, or callable, optional
.. deprecated:: 1.6.0
Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
early_stopping_rounds : int
.. deprecated:: 1.6.0
Use `early_stopping_rounds` in :py:meth:`__init__` or :py:meth:`set_params`
instead.
verbose :
If `verbose` is True and an evaluation set is used, the evaluation metric
measured on the validation set is printed to stdout at each boosting stage.
@ -1049,10 +986,6 @@ class XGBModel(XGBModelBase):
selected when colsample is being used. All values must be greater than 0,
otherwise a `ValueError` is thrown.
callbacks :
.. deprecated:: 1.6.0
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
"""
with config_context(verbosity=self.verbosity):
evals_result: TrainingCallback.EvalsLog = {}
@ -1082,27 +1015,19 @@ class XGBModel(XGBModelBase):
else:
obj = None
(
model,
metric,
params,
early_stopping_rounds,
callbacks,
) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
model, metric, params = self._configure_fit(xgb_model, params)
self._Booster = train(
params,
train_dmatrix,
self.get_num_boosting_rounds(),
evals=evals,
early_stopping_rounds=early_stopping_rounds,
early_stopping_rounds=self.early_stopping_rounds,
evals_result=evals_result,
obj=obj,
custom_metric=metric,
verbose_eval=verbose,
xgb_model=model,
callbacks=callbacks,
callbacks=self.callbacks,
)
self._set_evaluation_result(evals_result)
@ -1437,14 +1362,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
sample_weight: Optional[ArrayLike] = None,
base_margin: Optional[ArrayLike] = None,
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBClassifier":
# pylint: disable = attribute-defined-outside-init,too-many-statements
with config_context(verbosity=self.verbosity):
@ -1492,15 +1414,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
params["objective"] = "multi:softprob"
params["num_class"] = self.n_classes_
(
model,
metric,
params,
early_stopping_rounds,
callbacks,
) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
model, metric, params = self._configure_fit(xgb_model, params)
train_dmatrix, evals = _wrap_evaluation_matrices(
missing=self.missing,
X=X,
@ -1525,13 +1439,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
train_dmatrix,
self.get_num_boosting_rounds(),
evals=evals,
early_stopping_rounds=early_stopping_rounds,
early_stopping_rounds=self.early_stopping_rounds,
evals_result=evals_result,
obj=obj,
custom_metric=metric,
verbose_eval=verbose,
xgb_model=model,
callbacks=callbacks,
callbacks=self.callbacks,
)
if not callable(self.objective):
@ -1693,17 +1607,14 @@ class XGBRFClassifier(XGBClassifier):
sample_weight: Optional[ArrayLike] = None,
base_margin: Optional[ArrayLike] = None,
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBRFClassifier":
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks)
_check_rf_callback(self.early_stopping_rounds, self.callbacks)
super().fit(**args)
return self
@ -1768,17 +1679,14 @@ class XGBRFRegressor(XGBRegressor):
sample_weight: Optional[ArrayLike] = None,
base_margin: Optional[ArrayLike] = None,
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBRFRegressor":
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks)
_check_rf_callback(self.early_stopping_rounds, self.callbacks)
super().fit(**args)
return self
@ -1883,14 +1791,11 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_group: Optional[Sequence[ArrayLike]] = None,
eval_qid: Optional[Sequence[ArrayLike]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = False,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBRanker":
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""Fit gradient boosting ranker
@ -1960,15 +1865,6 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
pair in **eval_set**. The special column convention in `X` applies to
validation datasets as well.
eval_metric : str, list of str, optional
.. deprecated:: 1.6.0
use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
early_stopping_rounds : int
.. deprecated:: 1.6.0
use `early_stopping_rounds` in :py:meth:`__init__` or
:py:meth:`set_params` instead.
verbose :
If `verbose` is True and an evaluation set is used, the evaluation metric
measured on the validation set is printed to stdout at each boosting stage.
@ -1996,10 +1892,6 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
selected when colsample is being used. All values must be greater than 0,
otherwise a `ValueError` is thrown.
callbacks :
.. deprecated:: 1.6.0
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
"""
with config_context(verbosity=self.verbosity):
train_dmatrix, evals = _wrap_evaluation_matrices(
@ -2024,27 +1916,19 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
evals_result: TrainingCallback.EvalsLog = {}
params = self.get_xgb_params()
(
model,
metric,
params,
early_stopping_rounds,
callbacks,
) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
model, metric, params = self._configure_fit(xgb_model, params)
self._Booster = train(
params,
train_dmatrix,
num_boost_round=self.get_num_boosting_rounds(),
early_stopping_rounds=early_stopping_rounds,
early_stopping_rounds=self.early_stopping_rounds,
evals=evals,
evals_result=evals_result,
custom_metric=metric,
verbose_eval=verbose,
xgb_model=model,
callbacks=callbacks,
callbacks=self.callbacks,
)
self.objective = params["objective"]

View File

@ -18,10 +18,12 @@ class LintersPaths:
"python-package/",
# tests
"tests/python/test_config.py",
"tests/python/test_callback.py",
"tests/python/test_data_iterator.py",
"tests/python/test_dmatrix.py",
"tests/python/test_dt.py",
"tests/python/test_demos.py",
"tests/python/test_eval_metrics.py",
"tests/python/test_multi_target.py",
"tests/python/test_predict.py",
"tests/python/test_quantile_dmatrix.py",
@ -39,12 +41,15 @@ class LintersPaths:
"demo/dask/",
"demo/rmm_plugin",
"demo/json-model/json_parser.py",
"demo/guide-python/continuation.py",
"demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/callbacks.py",
"demo/guide-python/categorical.py",
"demo/guide-python/cat_pipeline.py",
"demo/guide-python/feature_weights.py",
"demo/guide-python/sklearn_parallel.py",
"demo/guide-python/sklearn_examples.py",
"demo/guide-python/sklearn_evals_result.py",
"demo/guide-python/spark_estimator_examples.py",
"demo/guide-python/external_memory.py",
"demo/guide-python/individual_trees.py",
@ -93,6 +98,7 @@ class LintersPaths:
# demo
"demo/json-model/json_parser.py",
"demo/guide-python/external_memory.py",
"demo/guide-python/continuation.py",
"demo/guide-python/callbacks.py",
"demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/categorical.py",

View File

@ -16,13 +16,14 @@ class TestCallbacks:
@classmethod
def setup_class(cls):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
cls.X = X
cls.y = y
split = int(X.shape[0]*0.8)
cls.X_train = X[: split, ...]
cls.y_train = y[: split, ...]
split = int(X.shape[0] * 0.8)
cls.X_train = X[:split, ...]
cls.y_train = y[:split, ...]
cls.X_valid = X[split:, ...]
cls.y_valid = y[split:, ...]
@ -31,31 +32,32 @@ class TestCallbacks:
D_train: xgb.DMatrix,
D_valid: xgb.DMatrix,
rounds: int,
verbose_eval: Union[bool, int]
verbose_eval: Union[bool, int],
):
def check_output(output: str) -> None:
if int(verbose_eval) == 1:
# Should print each iteration info
assert len(output.split('\n')) == rounds
assert len(output.split("\n")) == rounds
elif int(verbose_eval) > rounds:
# Should print first and latest iteration info
assert len(output.split('\n')) == 2
assert len(output.split("\n")) == 2
else:
# Should print info by each period additionaly to first and latest
# iteration
num_periods = rounds // int(verbose_eval)
# Extra information is required for latest iteration
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
assert len(output.split('\n')) == (
assert len(output.split("\n")) == (
1 + num_periods + int(is_extra_info_required)
)
evals_result: xgb.callback.TrainingCallback.EvalsLog = {}
params = {'objective': 'binary:logistic', 'eval_metric': 'error'}
params = {"objective": "binary:logistic", "eval_metric": "error"}
with tm.captured_output() as (out, err):
xgb.train(
params, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
params,
D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
num_boost_round=rounds,
evals_result=evals_result,
verbose_eval=verbose_eval,
@ -73,14 +75,16 @@ class TestCallbacks:
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
evals_result = {}
rounds = 10
xgb.train({'objective': 'binary:logistic',
'eval_metric': 'error'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
xgb.train(
{"objective": "binary:logistic", "eval_metric": "error"},
D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
num_boost_round=rounds,
evals_result=evals_result,
verbose_eval=True)
assert len(evals_result['Train']['error']) == rounds
assert len(evals_result['Valid']['error']) == rounds
verbose_eval=True,
)
assert len(evals_result["Train"]["error"]) == rounds
assert len(evals_result["Valid"]["error"]) == rounds
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
@ -93,72 +97,83 @@ class TestCallbacks:
evals_result = {}
rounds = 30
early_stopping_rounds = 5
booster = xgb.train({'objective': 'binary:logistic',
'eval_metric': 'error'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
booster = xgb.train(
{"objective": "binary:logistic", "eval_metric": "error"},
D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
num_boost_round=rounds,
evals_result=evals_result,
verbose_eval=True,
early_stopping_rounds=early_stopping_rounds)
dump = booster.get_dump(dump_format='json')
early_stopping_rounds=early_stopping_rounds,
)
dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_early_stopping_custom_eval(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
early_stopping_rounds = 5
booster = xgb.train({'objective': 'binary:logistic',
'eval_metric': 'error',
'tree_method': 'hist'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
booster = xgb.train(
{
"objective": "binary:logistic",
"eval_metric": "error",
"tree_method": "hist",
},
D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
feval=tm.eval_error_metric,
num_boost_round=1000,
early_stopping_rounds=early_stopping_rounds,
verbose_eval=False)
dump = booster.get_dump(dump_format='json')
verbose_eval=False,
)
dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_early_stopping_customize(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
metric_name='CustomErr',
data_name='Train')
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds, metric_name="CustomErr", data_name="Train"
)
# Specify which dataset and which metric should be used for early stopping.
booster = xgb.train(
{'objective': 'binary:logistic',
'eval_metric': ['error', 'rmse'],
'tree_method': 'hist'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
{
"objective": "binary:logistic",
"eval_metric": ["error", "rmse"],
"tree_method": "hist",
},
D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
feval=tm.eval_error_metric,
num_boost_round=1000,
callbacks=[early_stop],
verbose_eval=False)
dump = booster.get_dump(dump_format='json')
verbose_eval=False,
)
dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump)
assert len(early_stop.stopping_history["Train"]["CustomErr"]) == len(dump)
rounds = 100
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds,
metric_name='CustomErr',
data_name='Train',
metric_name="CustomErr",
data_name="Train",
min_delta=100,
save_best=True,
)
booster = xgb.train(
{
'objective': 'binary:logistic',
'eval_metric': ['error', 'rmse'],
'tree_method': 'hist'
"objective": "binary:logistic",
"eval_metric": ["error", "rmse"],
"tree_method": "hist",
},
D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
evals=[(D_train, "Train"), (D_valid, "Valid")],
feval=tm.eval_error_metric,
num_boost_round=rounds,
callbacks=[early_stop],
verbose_eval=False
verbose_eval=False,
)
# No iteration can be made with min_delta == 100
assert booster.best_iteration == 0
@ -166,18 +181,20 @@ class TestCallbacks:
def test_early_stopping_skl(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
early_stopping_rounds = 5
cls = xgb.XGBClassifier(
early_stopping_rounds=early_stopping_rounds, eval_metric='error'
early_stopping_rounds=early_stopping_rounds, eval_metric="error"
)
cls.fit(X, y, eval_set=[(X, y)])
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_early_stopping_custom_eval_skl(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds)
@ -186,11 +203,12 @@ class TestCallbacks:
)
cls.fit(X, y, eval_set=[(X, y)])
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_early_stopping_save_best_model(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
n_estimators = 100
early_stopping_rounds = 5
@ -200,11 +218,11 @@ class TestCallbacks:
cls = xgb.XGBClassifier(
n_estimators=n_estimators,
eval_metric=tm.eval_error_metric_skl,
callbacks=[early_stop]
callbacks=[early_stop],
)
cls.fit(X, y, eval_set=[(X, y)])
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
dump = booster.get_dump(dump_format="json")
assert len(dump) == booster.best_iteration + 1
early_stop = xgb.callback.EarlyStopping(
@ -220,8 +238,9 @@ class TestCallbacks:
cls.fit(X, y, eval_set=[(X, y)])
# No error
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=False)
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds, save_best=False
)
xgb.XGBClassifier(
booster="gblinear",
n_estimators=10,
@ -231,14 +250,17 @@ class TestCallbacks:
def test_early_stopping_continuation(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
cls = xgb.XGBClassifier(eval_metric=tm.eval_error_metric_skl)
early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds, save_best=True
)
with pytest.warns(UserWarning):
cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
cls = xgb.XGBClassifier(
eval_metric=tm.eval_error_metric_skl, callbacks=[early_stop]
)
cls.fit(X, y, eval_set=[(X, y)])
booster = cls.get_booster()
assert booster.num_boosted_rounds() == booster.best_iteration + 1
@ -256,21 +278,10 @@ class TestCallbacks:
)
cls.fit(X, y, eval_set=[(X, y)])
booster = cls.get_booster()
assert booster.num_boosted_rounds() == \
booster.best_iteration + early_stopping_rounds + 1
def test_deprecated(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds, save_best=True
assert (
booster.num_boosted_rounds()
== booster.best_iteration + early_stopping_rounds + 1
)
clf = xgb.XGBClassifier(
eval_metric=tm.eval_error_metric_skl, callbacks=[early_stop]
)
with pytest.raises(ValueError, match=r".*set_params.*"):
clf.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
def run_eta_decay(self, tree_method):
"""Test learning rate scheduler, used by both CPU and GPU tests."""
@ -343,7 +354,7 @@ class TestCallbacks:
callbacks=[scheduler([0, 0, 0, 0])],
evals_result=evals_result,
)
eval_errors_2 = list(map(float, evals_result['eval']['error']))
eval_errors_2 = list(map(float, evals_result["eval"]["error"]))
assert isinstance(bst, xgb.core.Booster)
# validation error should not decrease, if eta/learning_rate = 0
assert eval_errors_2[0] == eval_errors_2[-1]
@ -361,7 +372,7 @@ class TestCallbacks:
callbacks=[scheduler(eta_decay)],
evals_result=evals_result,
)
eval_errors_3 = list(map(float, evals_result['eval']['error']))
eval_errors_3 = list(map(float, evals_result["eval"]["error"]))
assert isinstance(bst, xgb.core.Booster)

View File

@ -15,23 +15,23 @@ class TestEarlyStopping:
from sklearn.model_selection import train_test_split
digits = load_digits(n_class=2)
X = digits['data']
y = digits['target']
X = digits["data"]
y = digits["target"]
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf1 = xgb.XGBClassifier(learning_rate=0.1)
clf1.fit(X_train, y_train, early_stopping_rounds=5, eval_metric="auc",
eval_set=[(X_test, y_test)])
clf2 = xgb.XGBClassifier(learning_rate=0.1)
clf2.fit(X_train, y_train, early_stopping_rounds=4, eval_metric="auc",
eval_set=[(X_test, y_test)])
clf1 = xgb.XGBClassifier(
learning_rate=0.1, early_stopping_rounds=5, eval_metric="auc"
)
clf1.fit(X_train, y_train, eval_set=[(X_test, y_test)])
clf2 = xgb.XGBClassifier(
learning_rate=0.1, early_stopping_rounds=4, eval_metric="auc"
)
clf2.fit(X_train, y_train, eval_set=[(X_test, y_test)])
# should be the same
assert clf1.best_score == clf2.best_score
assert clf1.best_score != 1
# check overfit
clf3 = xgb.XGBClassifier(
learning_rate=0.1,
eval_metric="auc",
early_stopping_rounds=10
learning_rate=0.1, eval_metric="auc", early_stopping_rounds=10
)
clf3.fit(X_train, y_train, eval_set=[(X_test, y_test)])
base_score = get_basescore(clf3)
@ -39,9 +39,9 @@ class TestEarlyStopping:
clf3 = xgb.XGBClassifier(
learning_rate=0.1,
base_score=.5,
base_score=0.5,
eval_metric="auc",
early_stopping_rounds=10
early_stopping_rounds=10,
)
clf3.fit(X_train, y_train, eval_set=[(X_test, y_test)])

View File

@ -9,37 +9,41 @@ rng = np.random.RandomState(1337)
class TestEvalMetrics:
xgb_params_01 = {'nthread': 1, 'eval_metric': 'error'}
xgb_params_01 = {"nthread": 1, "eval_metric": "error"}
xgb_params_02 = {'nthread': 1, 'eval_metric': ['error']}
xgb_params_02 = {"nthread": 1, "eval_metric": ["error"]}
xgb_params_03 = {'nthread': 1, 'eval_metric': ['rmse', 'error']}
xgb_params_03 = {"nthread": 1, "eval_metric": ["rmse", "error"]}
xgb_params_04 = {'nthread': 1, 'eval_metric': ['error', 'rmse']}
xgb_params_04 = {"nthread": 1, "eval_metric": ["error", "rmse"]}
def evalerror_01(self, preds, dtrain):
labels = dtrain.get_label()
return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
return "error", float(sum(labels != (preds > 0.0))) / len(labels)
def evalerror_02(self, preds, dtrain):
labels = dtrain.get_label()
return [('error', float(sum(labels != (preds > 0.0))) / len(labels))]
return [("error", float(sum(labels != (preds > 0.0))) / len(labels))]
@pytest.mark.skipif(**tm.no_sklearn())
def evalerror_03(self, preds, dtrain):
from sklearn.metrics import mean_squared_error
labels = dtrain.get_label()
return [('rmse', mean_squared_error(labels, preds)),
('error', float(sum(labels != (preds > 0.0))) / len(labels))]
return [
("rmse", mean_squared_error(labels, preds)),
("error", float(sum(labels != (preds > 0.0))) / len(labels)),
]
@pytest.mark.skipif(**tm.no_sklearn())
def evalerror_04(self, preds, dtrain):
from sklearn.metrics import mean_squared_error
labels = dtrain.get_label()
return [('error', float(sum(labels != (preds > 0.0))) / len(labels)),
('rmse', mean_squared_error(labels, preds))]
return [
("error", float(sum(labels != (preds > 0.0))) / len(labels)),
("rmse", mean_squared_error(labels, preds)),
]
@pytest.mark.skipif(**tm.no_sklearn())
def test_eval_metrics(self):
@ -50,15 +54,15 @@ class TestEvalMetrics:
from sklearn.datasets import load_digits
digits = load_digits(n_class=2)
X = digits['data']
y = digits['target']
X = digits["data"]
y = digits["target"]
Xt, Xv, yt, yv = train_test_split(X, y, test_size=0.2, random_state=0)
dtrain = xgb.DMatrix(Xt, label=yt)
dvalid = xgb.DMatrix(Xv, label=yv)
watchlist = [(dtrain, 'train'), (dvalid, 'val')]
watchlist = [(dtrain, "train"), (dvalid, "val")]
gbdt_01 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10)
gbdt_02 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=10)
@ -66,26 +70,54 @@ class TestEvalMetrics:
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
gbdt_01 = xgb.train(self.xgb_params_01, dtrain, 10, watchlist,
early_stopping_rounds=2)
gbdt_02 = xgb.train(self.xgb_params_02, dtrain, 10, watchlist,
early_stopping_rounds=2)
gbdt_03 = xgb.train(self.xgb_params_03, dtrain, 10, watchlist,
early_stopping_rounds=2)
gbdt_04 = xgb.train(self.xgb_params_04, dtrain, 10, watchlist,
early_stopping_rounds=2)
gbdt_01 = xgb.train(
self.xgb_params_01, dtrain, 10, watchlist, early_stopping_rounds=2
)
gbdt_02 = xgb.train(
self.xgb_params_02, dtrain, 10, watchlist, early_stopping_rounds=2
)
gbdt_03 = xgb.train(
self.xgb_params_03, dtrain, 10, watchlist, early_stopping_rounds=2
)
gbdt_04 = xgb.train(
self.xgb_params_04, dtrain, 10, watchlist, early_stopping_rounds=2
)
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0]
gbdt_01 = xgb.train(self.xgb_params_01, dtrain, 10, watchlist,
early_stopping_rounds=2, feval=self.evalerror_01)
gbdt_02 = xgb.train(self.xgb_params_02, dtrain, 10, watchlist,
early_stopping_rounds=2, feval=self.evalerror_02)
gbdt_03 = xgb.train(self.xgb_params_03, dtrain, 10, watchlist,
early_stopping_rounds=2, feval=self.evalerror_03)
gbdt_04 = xgb.train(self.xgb_params_04, dtrain, 10, watchlist,
early_stopping_rounds=2, feval=self.evalerror_04)
gbdt_01 = xgb.train(
self.xgb_params_01,
dtrain,
10,
watchlist,
early_stopping_rounds=2,
feval=self.evalerror_01,
)
gbdt_02 = xgb.train(
self.xgb_params_02,
dtrain,
10,
watchlist,
early_stopping_rounds=2,
feval=self.evalerror_02,
)
gbdt_03 = xgb.train(
self.xgb_params_03,
dtrain,
10,
watchlist,
early_stopping_rounds=2,
feval=self.evalerror_03,
)
gbdt_04 = xgb.train(
self.xgb_params_04,
dtrain,
10,
watchlist,
early_stopping_rounds=2,
feval=self.evalerror_04,
)
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0]
@ -93,6 +125,7 @@ class TestEvalMetrics:
@pytest.mark.skipif(**tm.no_sklearn())
def test_gamma_deviance(self):
from sklearn.metrics import mean_gamma_deviance
rng = np.random.RandomState(1994)
n_samples = 100
n_features = 30
@ -101,8 +134,13 @@ class TestEvalMetrics:
y = rng.randn(n_samples)
y = y - y.min() * 100
reg = xgb.XGBRegressor(tree_method="hist", objective="reg:gamma", n_estimators=10)
reg.fit(X, y, eval_metric="gamma-deviance")
reg = xgb.XGBRegressor(
tree_method="hist",
objective="reg:gamma",
n_estimators=10,
eval_metric="gamma-deviance",
)
reg.fit(X, y)
booster = reg.get_booster()
score = reg.predict(X)
@ -113,16 +151,26 @@ class TestEvalMetrics:
@pytest.mark.skipif(**tm.no_sklearn())
def test_gamma_lik(self) -> None:
import scipy.stats as stats
rng = np.random.default_rng(1994)
n_samples = 32
n_features = 10
X = rng.normal(0, 1, size=n_samples * n_features).reshape((n_samples, n_features))
X = rng.normal(0, 1, size=n_samples * n_features).reshape(
(n_samples, n_features)
)
alpha, loc, beta = 5.0, 11.1, 22
y = stats.gamma.rvs(alpha, loc=loc, scale=beta, size=n_samples, random_state=rng)
reg = xgb.XGBRegressor(tree_method="hist", objective="reg:gamma", n_estimators=64)
reg.fit(X, y, eval_metric="gamma-nloglik", eval_set=[(X, y)])
y = stats.gamma.rvs(
alpha, loc=loc, scale=beta, size=n_samples, random_state=rng
)
reg = xgb.XGBRegressor(
tree_method="hist",
objective="reg:gamma",
n_estimators=64,
eval_metric="gamma-nloglik",
)
reg.fit(X, y, eval_set=[(X, y)])
score = reg.predict(X)
@ -134,7 +182,7 @@ class TestEvalMetrics:
# XGBoost uses the canonical link function of gamma in evaluation function.
# so \theta = - (1.0 / y)
# dispersion is hardcoded as 1.0, so shape (a in scipy parameter) is also 1.0
beta = - (1.0 / (- (1.0 / y))) # == y
beta = -(1.0 / (-(1.0 / y))) # == y
nloglik_stats = -stats.gamma.logpdf(score, a=1.0, scale=beta)
np.testing.assert_allclose(nloglik, np.mean(nloglik_stats), rtol=1e-3)
@ -153,7 +201,7 @@ class TestEvalMetrics:
n_features,
n_informative=n_features,
n_redundant=0,
random_state=rng
random_state=rng,
)
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
@ -197,7 +245,7 @@ class TestEvalMetrics:
n_informative=n_features,
n_redundant=0,
n_classes=n_classes,
random_state=rng
random_state=rng,
)
if weighted:
weights = rng.randn(n_samples)
@ -242,20 +290,25 @@ class TestEvalMetrics:
def run_pr_auc_binary(self, tree_method):
from sklearn.datasets import make_classification
from sklearn.metrics import auc, precision_recall_curve
X, y = make_classification(128, 4, n_classes=2, random_state=1994)
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=1)
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])
clf = xgb.XGBClassifier(
tree_method=tree_method, n_estimators=1, eval_metric="aucpr"
)
clf.fit(X, y, eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
y_score = clf.predict_proba(X)[:, 1] # get the positive column
precision, recall, _ = precision_recall_curve(y, y_score)
prauc = auc(recall, precision)
# Interpolation results are slightly different from sklearn, but overall should be
# similar.
# Interpolation results are slightly different from sklearn, but overall should
# be similar.
np.testing.assert_allclose(prauc, evals_result, rtol=1e-2)
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=10)
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])
clf = xgb.XGBClassifier(
tree_method=tree_method, n_estimators=10, eval_metric="aucpr"
)
clf.fit(X, y, eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
np.testing.assert_allclose(0.99, evals_result, rtol=1e-2)
@ -264,16 +317,21 @@ class TestEvalMetrics:
def run_pr_auc_multi(self, tree_method):
from sklearn.datasets import make_classification
X, y = make_classification(
64, 16, n_informative=8, n_classes=3, random_state=1994
)
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=1)
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])
clf = xgb.XGBClassifier(
tree_method=tree_method, n_estimators=1, eval_metric="aucpr"
)
clf.fit(X, y, eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
# No available implementation for comparison, just check that XGBoost converges to
# 1.0
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=10)
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])
# No available implementation for comparison, just check that XGBoost converges
# to 1.0
clf = xgb.XGBClassifier(
tree_method=tree_method, n_estimators=10, eval_metric="aucpr"
)
clf.fit(X, y, eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
np.testing.assert_allclose(1.0, evals_result, rtol=1e-2)
@ -282,9 +340,13 @@ class TestEvalMetrics:
def run_pr_auc_ltr(self, tree_method):
from sklearn.datasets import make_classification
X, y = make_classification(128, 4, n_classes=2, random_state=1994)
ltr = xgb.XGBRanker(
tree_method=tree_method, n_estimators=16, objective="rank:pairwise"
tree_method=tree_method,
n_estimators=16,
objective="rank:pairwise",
eval_metric="aucpr",
)
groups = np.array([32, 32, 64])
ltr.fit(
@ -293,7 +355,6 @@ class TestEvalMetrics:
group=groups,
eval_set=[(X, y)],
eval_group=[groups],
eval_metric="aucpr",
)
results = ltr.evals_result()["validation_0"]["aucpr"]
assert results[-1] >= 0.99

View File

@ -149,8 +149,8 @@ class TestTrainingContinuation:
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
clf = xgb.XGBClassifier(n_estimators=2)
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
clf = xgb.XGBClassifier(n_estimators=2, eval_metric="logloss")
clf.fit(X, y, eval_set=[(X, y)])
assert tm.non_increasing(clf.evals_result()["validation_0"]["logloss"])
with tempfile.TemporaryDirectory() as tmpdir:
@ -160,5 +160,6 @@ class TestTrainingContinuation:
clf = xgb.XGBClassifier(n_estimators=2)
# change metric to error
clf.fit(X, y, eval_set=[(X, y)], eval_metric="error")
clf.set_params(eval_metric="error")
clf.fit(X, y, eval_set=[(X, y)], xgb_model=loaded)
assert tm.non_increasing(clf.evals_result()["validation_0"]["error"])

View File

@ -30,8 +30,8 @@ def test_binary_classification():
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
for train_index, test_index in kf.split(X, y):
clf = cls(random_state=42)
xgb_model = clf.fit(X[train_index], y[train_index], eval_metric=['auc', 'logloss'])
clf = cls(random_state=42, eval_metric=['auc', 'logloss'])
xgb_model = clf.fit(X[train_index], y[train_index])
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
err = sum(1 for i in range(len(preds))
@ -101,10 +101,11 @@ def test_best_iteration():
def train(booster: str, forest: Optional[int]) -> None:
rounds = 4
cls = xgb.XGBClassifier(
n_estimators=rounds, num_parallel_tree=forest, booster=booster
).fit(
X, y, eval_set=[(X, y)], early_stopping_rounds=3
)
n_estimators=rounds,
num_parallel_tree=forest,
booster=booster,
early_stopping_rounds=3,
).fit(X, y, eval_set=[(X, y)])
assert cls.best_iteration == rounds - 1
# best_iteration is used by default, assert that under gblinear it's
@ -112,9 +113,9 @@ def test_best_iteration():
cls.predict(X)
num_parallel_tree = 4
train('gbtree', num_parallel_tree)
train('dart', num_parallel_tree)
train('gblinear', None)
train("gbtree", num_parallel_tree)
train("dart", num_parallel_tree)
train("gblinear", None)
def test_ranking():
@ -258,6 +259,7 @@ def test_stacking_classification():
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
clf.fit(X_train, y_train).score(X_test, y_test)
@pytest.mark.skipif(**tm.no_pandas())
def test_feature_importances_weight():
from sklearn.datasets import load_digits
@ -474,7 +476,8 @@ def run_housing_rf_regression(tree_method):
rfreg = xgb.XGBRFRegressor()
with pytest.raises(NotImplementedError):
rfreg.fit(X, y, early_stopping_rounds=10)
rfreg.set_params(early_stopping_rounds=10)
rfreg.fit(X, y)
def test_rf_regression():
@ -844,51 +847,65 @@ def run_validation_weights(model):
y_train, y_test = y[:1600], y[1600:]
# instantiate model
param_dist = {'objective': 'binary:logistic', 'n_estimators': 2,
'random_state': 123}
param_dist = {
"objective": "binary:logistic",
"n_estimators": 2,
"random_state": 123,
}
clf = model(**param_dist)
# train it using instance weights only in the training set
weights_train = np.random.choice([1, 2], len(X_train))
clf.fit(X_train, y_train,
clf.set_params(eval_metric="logloss")
clf.fit(
X_train,
y_train,
sample_weight=weights_train,
eval_set=[(X_test, y_test)],
eval_metric='logloss',
verbose=False)
verbose=False,
)
# evaluate logloss metric on test set *without* using weights
evals_result_without_weights = clf.evals_result()
logloss_without_weights = evals_result_without_weights[
"validation_0"]["logloss"]
logloss_without_weights = evals_result_without_weights["validation_0"]["logloss"]
# now use weights for the test set
np.random.seed(0)
weights_test = np.random.choice([1, 2], len(X_test))
clf.fit(X_train, y_train,
clf.set_params(eval_metric="logloss")
clf.fit(
X_train,
y_train,
sample_weight=weights_train,
eval_set=[(X_test, y_test)],
sample_weight_eval_set=[weights_test],
eval_metric='logloss',
verbose=False)
verbose=False,
)
evals_result_with_weights = clf.evals_result()
logloss_with_weights = evals_result_with_weights["validation_0"]["logloss"]
# check that the logloss in the test set is actually different when using
# weights than when not using them
assert all((logloss_with_weights[i] != logloss_without_weights[i]
for i in [0, 1]))
assert all((logloss_with_weights[i] != logloss_without_weights[i] for i in [0, 1]))
with pytest.raises(ValueError):
# length of eval set and sample weight doesn't match.
clf.fit(X_train, y_train, sample_weight=weights_train,
clf.fit(
X_train,
y_train,
sample_weight=weights_train,
eval_set=[(X_train, y_train), (X_test, y_test)],
sample_weight_eval_set=[weights_train])
sample_weight_eval_set=[weights_train],
)
with pytest.raises(ValueError):
cls = xgb.XGBClassifier()
cls.fit(X_train, y_train, sample_weight=weights_train,
cls.fit(
X_train,
y_train,
sample_weight=weights_train,
eval_set=[(X_train, y_train), (X_test, y_test)],
sample_weight_eval_set=[weights_train])
sample_weight_eval_set=[weights_train],
)
def test_validation_weights():
@ -960,8 +977,7 @@ def test_XGBClassifier_resume():
# file name of stored xgb model
model1.save_model(model1_path)
model2 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8)
model2 = xgb.XGBClassifier(learning_rate=0.3, random_state=0, n_estimators=8)
model2.fit(X, Y, xgb_model=model1_path)
pred2 = model2.predict(X)
@ -972,8 +988,7 @@ def test_XGBClassifier_resume():
# file name of 'Booster' instance Xgb model
model1.get_booster().save_model(model1_booster_path)
model2 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8)
model2 = xgb.XGBClassifier(learning_rate=0.3, random_state=0, n_estimators=8)
model2.fit(X, Y, xgb_model=model1_booster_path)
pred2 = model2.predict(X)
@ -1279,12 +1294,16 @@ def test_estimator_reg(estimator, check):
):
estimator.fit(X, y)
return
if os.environ["PYTEST_CURRENT_TEST"].find("check_estimators_overwrite_params") != -1:
if (
os.environ["PYTEST_CURRENT_TEST"].find("check_estimators_overwrite_params")
!= -1
):
# A hack to pass the scikit-learn parameter mutation tests. XGBoost regressor
# returns actual internal default values for parameters in `get_params`, but those
# are set as `None` in sklearn interface to avoid duplication. So we fit a dummy
# model and obtain the default parameters here for the mutation tests.
# returns actual internal default values for parameters in `get_params`, but
# those are set as `None` in sklearn interface to avoid duplication. So we fit
# a dummy model and obtain the default parameters here for the mutation tests.
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=2, n_features=1)
estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params())
@ -1325,6 +1344,7 @@ def test_categorical():
def test_evaluation_metric():
from sklearn.datasets import load_diabetes, load_digits
from sklearn.metrics import mean_absolute_error
X, y = load_diabetes(return_X_y=True)
n_estimators = 16
@ -1341,17 +1361,6 @@ def test_evaluation_metric():
for line in lines:
assert line.find("mean_absolute_error") != -1
def metric(predt: np.ndarray, Xy: xgb.DMatrix):
y = Xy.get_label()
return "m", np.abs(predt - y).sum()
with pytest.warns(UserWarning):
reg = xgb.XGBRegressor(
tree_method="hist",
n_estimators=1,
)
reg.fit(X, y, eval_set=[(X, y)], eval_metric=metric)
def merror(y_true: np.ndarray, predt: np.ndarray):
n_samples = y_true.shape[0]
assert n_samples == predt.size

View File

@ -363,12 +363,12 @@ class TestDistributedGPU:
device="cuda",
eval_metric="error",
n_estimators=100,
early_stopping_rounds=early_stopping_rounds,
)
cls.client = local_cuda_client
cls.fit(
X,
y,
early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)],
)
booster = cls.get_booster()

View File

@ -937,8 +937,10 @@ def run_empty_dmatrix_auc(client: "Client", device: str, n_workers: int) -> None
valid_X = dd.from_array(valid_X_, chunksize=n_samples)
valid_y = dd.from_array(valid_y_, chunksize=n_samples)
cls = xgb.dask.DaskXGBClassifier(device=device, n_estimators=2)
cls.fit(X, y, eval_metric=["auc", "aucpr"], eval_set=[(valid_X, valid_y)])
cls = xgb.dask.DaskXGBClassifier(
device=device, n_estimators=2, eval_metric=["auc", "aucpr"]
)
cls.fit(X, y, eval_set=[(valid_X, valid_y)])
# multiclass
X_, y_ = make_classification(
@ -966,8 +968,10 @@ def run_empty_dmatrix_auc(client: "Client", device: str, n_workers: int) -> None
valid_X = dd.from_array(valid_X_, chunksize=n_samples)
valid_y = dd.from_array(valid_y_, chunksize=n_samples)
cls = xgb.dask.DaskXGBClassifier(device=device, n_estimators=2)
cls.fit(X, y, eval_metric=["auc", "aucpr"], eval_set=[(valid_X, valid_y)])
cls = xgb.dask.DaskXGBClassifier(
device=device, n_estimators=2, eval_metric=["auc", "aucpr"]
)
cls.fit(X, y, eval_set=[(valid_X, valid_y)])
def test_empty_dmatrix_auc() -> None:
@ -994,11 +998,11 @@ def run_auc(client: "Client", device: str) -> None:
valid_X = dd.from_array(valid_X_, chunksize=10)
valid_y = dd.from_array(valid_y_, chunksize=10)
cls = xgb.XGBClassifier(device=device, n_estimators=2)
cls.fit(X_, y_, eval_metric="auc", eval_set=[(valid_X_, valid_y_)])
cls = xgb.XGBClassifier(device=device, n_estimators=2, eval_metric="auc")
cls.fit(X_, y_, eval_set=[(valid_X_, valid_y_)])
dcls = xgb.dask.DaskXGBClassifier(device=device, n_estimators=2)
dcls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)])
dcls = xgb.dask.DaskXGBClassifier(device=device, n_estimators=2, eval_metric="auc")
dcls.fit(X, y, eval_set=[(valid_X, valid_y)])
approx = dcls.evals_result()["validation_0"]["auc"]
exact = cls.evals_result()["validation_0"]["auc"]
@ -1267,16 +1271,16 @@ def test_dask_ranking(client: "Client") -> None:
qid_valid = qid_valid.astype(np.uint32)
qid_test = qid_test.astype(np.uint32)
rank = xgb.dask.DaskXGBRanker(n_estimators=2500)
rank = xgb.dask.DaskXGBRanker(
n_estimators=2500, eval_metric=["ndcg"], early_stopping_rounds=10
)
rank.fit(
x_train,
y_train,
qid=qid_train,
eval_set=[(x_test, y_test), (x_train, y_train)],
eval_qid=[qid_test, qid_train],
eval_metric=["ndcg"],
verbose=True,
early_stopping_rounds=10,
)
assert rank.n_features_in_ == 46
assert rank.best_score > 0.98
@ -2150,13 +2154,15 @@ class TestDaskCallbacks:
valid_X, valid_y = load_breast_cancer(return_X_y=True)
valid_X, valid_y = da.from_array(valid_X), da.from_array(valid_y)
cls = xgb.dask.DaskXGBClassifier(
objective="binary:logistic", tree_method="hist", n_estimators=1000
objective="binary:logistic",
tree_method="hist",
n_estimators=1000,
early_stopping_rounds=early_stopping_rounds,
)
cls.client = client
cls.fit(
X,
y,
early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)],
)
booster = cls.get_booster()
@ -2165,15 +2171,17 @@ class TestDaskCallbacks:
# Specify the metric
cls = xgb.dask.DaskXGBClassifier(
objective="binary:logistic", tree_method="hist", n_estimators=1000
objective="binary:logistic",
tree_method="hist",
n_estimators=1000,
early_stopping_rounds=early_stopping_rounds,
eval_metric="error",
)
cls.client = client
cls.fit(
X,
y,
early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)],
eval_metric="error",
)
assert tm.non_increasing(cls.evals_result()["validation_0"]["error"])
booster = cls.get_booster()
@ -2215,12 +2223,12 @@ class TestDaskCallbacks:
tree_method="hist",
n_estimators=1000,
eval_metric=tm.eval_error_metric_skl,
early_stopping_rounds=early_stopping_rounds,
)
cls.client = client
cls.fit(
X,
y,
early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)],
)
booster = cls.get_booster()
@ -2234,21 +2242,22 @@ class TestDaskCallbacks:
X, y = load_breast_cancer(return_X_y=True)
X, y = da.from_array(X), da.from_array(y)
cls = xgb.dask.DaskXGBClassifier(
objective="binary:logistic", tree_method="hist", n_estimators=10
)
cls.client = client
with tempfile.TemporaryDirectory() as tmpdir:
cls.fit(
X,
y,
cls = xgb.dask.DaskXGBClassifier(
objective="binary:logistic",
tree_method="hist",
n_estimators=10,
callbacks=[
xgb.callback.TrainingCheckPoint(
directory=Path(tmpdir), interval=1, name="model"
)
],
)
cls.client = client
cls.fit(
X,
y,
)
for i in range(1, 10):
assert os.path.exists(
os.path.join(

View File

@ -311,24 +311,20 @@ def clf_with_weight(
y_val = np.array([0, 1])
w_train = np.array([1.0, 2.0])
w_val = np.array([1.0, 2.0])
cls2 = XGBClassifier()
cls2 = XGBClassifier(eval_metric="logloss", early_stopping_rounds=1)
cls2.fit(
X_train,
y_train,
eval_set=[(X_val, y_val)],
early_stopping_rounds=1,
eval_metric="logloss",
)
cls3 = XGBClassifier()
cls3 = XGBClassifier(eval_metric="logloss", early_stopping_rounds=1)
cls3.fit(
X_train,
y_train,
sample_weight=w_train,
eval_set=[(X_val, y_val)],
sample_weight_eval_set=[w_val],
early_stopping_rounds=1,
eval_metric="logloss",
)
cls_df_train_with_eval_weight = spark.createDataFrame(