[breaking] Remove deprecated parameters in the skl interface. (#9986)

Jiaming Yuan 2024-01-15 20:40:05 +08:00 committed by GitHub
parent 2de85d3241
commit 0798e36d73
16 changed files with 418 additions and 462 deletions
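
Summary of the breaking change: the long-deprecated `fit()` arguments of the scikit-learn estimators (`eval_metric`, `early_stopping_rounds`, `callbacks`) are removed; they must now be passed to the estimator constructor or via `set_params`. A minimal migration sketch (the dataset and parameter values are illustrative, not taken from the diff):

    import xgboost as xgb
    from sklearn.datasets import load_breast_cancer

    X, y = load_breast_cancer(return_X_y=True)

    # Before this commit, metric and early stopping could be fit() arguments:
    #   clf = xgb.XGBClassifier(n_estimators=100)
    #   clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", early_stopping_rounds=5)

    # After this commit, pass them to the constructor (or set_params) instead;
    # the same applies to `callbacks`, as shown in the continuation demo below.
    clf = xgb.XGBClassifier(
        n_estimators=100,
        eval_metric="logloss",
        early_stopping_rounds=5,
    )
    clf.fit(X, y, eval_set=[(X, y)])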


@ -16,14 +16,14 @@ def training_continuation(tmpdir: str, use_pickle: bool) -> None:
"""Basic training continuation.""" """Basic training continuation."""
# Train 128 iterations in 1 session # Train 128 iterations in 1 session
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
clf = xgboost.XGBClassifier(n_estimators=128) clf = xgboost.XGBClassifier(n_estimators=128, eval_metric="logloss")
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss") clf.fit(X, y, eval_set=[(X, y)])
print("Total boosted rounds:", clf.get_booster().num_boosted_rounds()) print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
# Train 128 iterations in 2 sessions, with the first one runs for 32 iterations and # Train 128 iterations in 2 sessions, with the first one runs for 32 iterations and
# the second one runs for 96 iterations # the second one runs for 96 iterations
clf = xgboost.XGBClassifier(n_estimators=32) clf = xgboost.XGBClassifier(n_estimators=32, eval_metric="logloss")
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss") clf.fit(X, y, eval_set=[(X, y)])
assert clf.get_booster().num_boosted_rounds() == 32 assert clf.get_booster().num_boosted_rounds() == 32
# load back the model, this could be a checkpoint # load back the model, this could be a checkpoint
@ -39,8 +39,8 @@ def training_continuation(tmpdir: str, use_pickle: bool) -> None:
loaded = xgboost.XGBClassifier() loaded = xgboost.XGBClassifier()
loaded.load_model(path) loaded.load_model(path)
clf = xgboost.XGBClassifier(n_estimators=128 - 32) clf = xgboost.XGBClassifier(n_estimators=128 - 32, eval_metric="logloss")
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", xgb_model=loaded) clf.fit(X, y, eval_set=[(X, y)], xgb_model=loaded)
print("Total boosted rounds:", clf.get_booster().num_boosted_rounds()) print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
@ -56,19 +56,24 @@ def training_continuation_early_stop(tmpdir: str, use_pickle: bool) -> None:
n_estimators = 512 n_estimators = 512
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
clf = xgboost.XGBClassifier(n_estimators=n_estimators) clf = xgboost.XGBClassifier(
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", callbacks=[early_stop]) n_estimators=n_estimators, eval_metric="logloss", callbacks=[early_stop]
)
clf.fit(X, y, eval_set=[(X, y)])
print("Total boosted rounds:", clf.get_booster().num_boosted_rounds()) print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
best = clf.best_iteration best = clf.best_iteration
# Train 512 iterations in 2 sessions, with the first one runs for 128 iterations and # Train 512 iterations in 2 sessions, with the first one runs for 128 iterations and
# the second one runs until early stop. # the second one runs until early stop.
clf = xgboost.XGBClassifier(n_estimators=128) clf = xgboost.XGBClassifier(
n_estimators=128, eval_metric="logloss", callbacks=[early_stop]
)
# Reinitialize the early stop callback # Reinitialize the early stop callback
early_stop = xgboost.callback.EarlyStopping( early_stop = xgboost.callback.EarlyStopping(
rounds=early_stopping_rounds, save_best=True rounds=early_stopping_rounds, save_best=True
) )
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", callbacks=[early_stop]) clf.set_params(callbacks=[early_stop])
clf.fit(X, y, eval_set=[(X, y)])
assert clf.get_booster().num_boosted_rounds() == 128 assert clf.get_booster().num_boosted_rounds() == 128
# load back the model, this could be a checkpoint # load back the model, this could be a checkpoint
@ -87,13 +92,13 @@ def training_continuation_early_stop(tmpdir: str, use_pickle: bool) -> None:
early_stop = xgboost.callback.EarlyStopping( early_stop = xgboost.callback.EarlyStopping(
rounds=early_stopping_rounds, save_best=True rounds=early_stopping_rounds, save_best=True
) )
clf = xgboost.XGBClassifier(n_estimators=n_estimators - 128) clf = xgboost.XGBClassifier(
n_estimators=n_estimators - 128, eval_metric="logloss", callbacks=[early_stop]
)
clf.fit( clf.fit(
X, X,
y, y,
eval_set=[(X, y)], eval_set=[(X, y)],
eval_metric="logloss",
callbacks=[early_stop],
xgb_model=loaded, xgb_model=loaded,
) )


@ -16,30 +16,35 @@ labels, y = np.unique(y, return_inverse=True)
X_train, X_test = X[:1600], X[1600:] X_train, X_test = X[:1600], X[1600:]
y_train, y_test = y[:1600], y[1600:] y_train, y_test = y[:1600], y[1600:]
param_dist = {'objective':'binary:logistic', 'n_estimators':2} param_dist = {"objective": "binary:logistic", "n_estimators": 2}
clf = xgb.XGBModel(**param_dist) clf = xgb.XGBModel(
**param_dist,
eval_metric="logloss",
)
# Or you can use: clf = xgb.XGBClassifier(**param_dist) # Or you can use: clf = xgb.XGBClassifier(**param_dist)
clf.fit(X_train, y_train, clf.fit(
X_train,
y_train,
eval_set=[(X_train, y_train), (X_test, y_test)], eval_set=[(X_train, y_train), (X_test, y_test)],
eval_metric='logloss', verbose=True,
verbose=True) )
# Load evals result by calling the evals_result() function # Load evals result by calling the evals_result() function
evals_result = clf.evals_result() evals_result = clf.evals_result()
print('Access logloss metric directly from validation_0:') print("Access logloss metric directly from validation_0:")
print(evals_result['validation_0']['logloss']) print(evals_result["validation_0"]["logloss"])
print('') print("")
print('Access metrics through a loop:') print("Access metrics through a loop:")
for e_name, e_mtrs in evals_result.items(): for e_name, e_mtrs in evals_result.items():
print('- {}'.format(e_name)) print("- {}".format(e_name))
for e_mtr_name, e_mtr_vals in e_mtrs.items(): for e_mtr_name, e_mtr_vals in e_mtrs.items():
print(' - {}'.format(e_mtr_name)) print(" - {}".format(e_mtr_name))
print(' - {}'.format(e_mtr_vals)) print(" - {}".format(e_mtr_vals))
print('') print("")
print('Access complete dict:') print("Access complete dict:")
print(evals_result) print(evals_result)


@ -1,4 +1,4 @@
''' """
Collection of examples for using sklearn interface Collection of examples for using sklearn interface
================================================== ==================================================
@ -8,7 +8,7 @@ For an introduction to XGBoost's scikit-learn estimator interface, see
Created on 1 Apr 2015 Created on 1 Apr 2015
@author: Jamie Hall @author: Jamie Hall
''' """
import pickle import pickle
import numpy as np import numpy as np
@ -22,8 +22,8 @@ rng = np.random.RandomState(31337)
print("Zeros and Ones from the Digits dataset: binary classification") print("Zeros and Ones from the Digits dataset: binary classification")
digits = load_digits(n_class=2) digits = load_digits(n_class=2)
y = digits['target'] y = digits["target"]
X = digits['data'] X = digits["data"]
kf = KFold(n_splits=2, shuffle=True, random_state=rng) kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X): for train_index, test_index in kf.split(X):
xgb_model = xgb.XGBClassifier(n_jobs=1).fit(X[train_index], y[train_index]) xgb_model = xgb.XGBClassifier(n_jobs=1).fit(X[train_index], y[train_index])
@ -33,8 +33,8 @@ for train_index, test_index in kf.split(X):
print("Iris: multiclass classification") print("Iris: multiclass classification")
iris = load_iris() iris = load_iris()
y = iris['target'] y = iris["target"]
X = iris['data'] X = iris["data"]
kf = KFold(n_splits=2, shuffle=True, random_state=rng) kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X): for train_index, test_index in kf.split(X):
xgb_model = xgb.XGBClassifier(n_jobs=1).fit(X[train_index], y[train_index]) xgb_model = xgb.XGBClassifier(n_jobs=1).fit(X[train_index], y[train_index])
@ -53,9 +53,13 @@ for train_index, test_index in kf.split(X):
print("Parameter optimization") print("Parameter optimization")
xgb_model = xgb.XGBRegressor(n_jobs=1) xgb_model = xgb.XGBRegressor(n_jobs=1)
clf = GridSearchCV(xgb_model, clf = GridSearchCV(
{'max_depth': [2, 4], xgb_model,
'n_estimators': [50, 100]}, verbose=1, n_jobs=1, cv=3) {"max_depth": [2, 4], "n_estimators": [50, 100]},
verbose=1,
n_jobs=1,
cv=3,
)
clf.fit(X, y) clf.fit(X, y)
print(clf.best_score_) print(clf.best_score_)
print(clf.best_params_) print(clf.best_params_)
@ -69,9 +73,8 @@ print(np.allclose(clf.predict(X), clf2.predict(X)))
# Early-stopping # Early-stopping
X = digits['data'] X = digits["data"]
y = digits['target'] y = digits["target"]
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = xgb.XGBClassifier(n_jobs=1) clf = xgb.XGBClassifier(n_jobs=1, early_stopping_rounds=10, eval_metric="auc")
clf.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="auc", clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
eval_set=[(X_test, y_test)])


@ -12,6 +12,7 @@ import xgboost as xgb
if __name__ == "__main__": if __name__ == "__main__":
print("Parallel Parameter optimization") print("Parallel Parameter optimization")
X, y = fetch_california_housing(return_X_y=True) X, y = fetch_california_housing(return_X_y=True)
# Make sure the number of threads is balanced.
xgb_model = xgb.XGBRegressor( xgb_model = xgb.XGBRegressor(
n_jobs=multiprocessing.cpu_count() // 2, tree_method="hist" n_jobs=multiprocessing.cpu_count() // 2, tree_method="hist"
) )
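
The comment added in this hunk is about balancing the two levels of parallelism: the regressor takes half the cores, leaving the rest for the outer grid search. A self-contained sketch of that pattern, assuming a scikit-learn grid search around the regressor (the parameter grid and `n_jobs=2` are illustrative, not necessarily the rest of the upstream demo):

    import multiprocessing

    from sklearn.datasets import fetch_california_housing
    from sklearn.model_selection import GridSearchCV

    import xgboost as xgb

    if __name__ == "__main__":
        X, y = fetch_california_housing(return_X_y=True)
        # Give half the cores to each XGBoost fit and let the grid search run
        # a couple of fits at a time, so the two levels of parallelism do not
        # oversubscribe the machine.
        xgb_model = xgb.XGBRegressor(
            n_jobs=multiprocessing.cpu_count() // 2, tree_method="hist"
        )
        clf = GridSearchCV(
            xgb_model,
            {"max_depth": [2, 4], "n_estimators": [50, 100]},
            n_jobs=2,
            verbose=1,
        )
        clf.fit(X, y)
        print(clf.best_score_, clf.best_params_)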


@ -123,11 +123,11 @@ monitor our model's performance. As mentioned above, the default metric for ``S
elements = np.power(np.log1p(y) - np.log1p(predt), 2) elements = np.power(np.log1p(y) - np.log1p(predt), 2)
return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y))) return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))
Since we are demonstrating in Python, the metric or objective need not be a function, any
callable object should suffice. Similar to the objective function, our metric also accepts
``predt`` and ``dtrain`` as inputs, but returns the name of the metric itself and a
floating point value as the result. After passing it into XGBoost as the argument of the
``custom_metric`` parameter (previously ``feval``):
.. code-block:: python .. code-block:: python
@ -136,7 +136,7 @@ parameter:
dtrain=dtrain, dtrain=dtrain,
num_boost_round=10, num_boost_round=10,
obj=squared_log, obj=squared_log,
feval=rmsle, custom_metric=rmsle,
evals=[(dtrain, 'dtrain'), (dtest, 'dtest')], evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
evals_result=results) evals_result=results)
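
With this commit, a callable metric goes through the estimator constructor in the scikit-learn interface as well. A minimal sketch of the scikit-learn counterpart of the tutorial's RMSLE metric; note that, unlike the `(predt, dtrain)` signature used with `xgb.train` above, the constructor's `eval_metric` is assumed here to take a scikit-learn style `(y_true, y_pred)` callable (the synthetic data and parameters are illustrative):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.random((256, 8))
    y = rng.random(256)

    def rmsle(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        # Root mean squared log error; clip predictions so log1p stays defined.
        y_pred = np.clip(y_pred, 0, None)
        return float(np.sqrt(np.mean(np.power(np.log1p(y_true) - np.log1p(y_pred), 2))))

    reg = xgb.XGBRegressor(n_estimators=10, tree_method="hist", eval_metric=rmsle)
    reg.fit(X, y, eval_set=[(X, y)])
    print(reg.evals_result())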


@ -61,7 +61,7 @@ from typing import (
import numpy import numpy
from xgboost import collective, config from xgboost import collective, config
from xgboost._typing import _T, FeatureNames, FeatureTypes, ModelIn from xgboost._typing import _T, FeatureNames, FeatureTypes
from xgboost.callback import TrainingCallback from xgboost.callback import TrainingCallback
from xgboost.compat import DataFrame, LazyLoader, concat, lazy_isinstance from xgboost.compat import DataFrame, LazyLoader, concat, lazy_isinstance
from xgboost.core import ( from xgboost.core import (
@ -1774,14 +1774,11 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
sample_weight: Optional[_DaskCollection], sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection], base_margin: Optional[_DaskCollection],
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]], eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
eval_metric: Optional[Union[str, Sequence[str], Metric]],
sample_weight_eval_set: Optional[Sequence[_DaskCollection]], sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
base_margin_eval_set: Optional[Sequence[_DaskCollection]], base_margin_eval_set: Optional[Sequence[_DaskCollection]],
early_stopping_rounds: Optional[int],
verbose: Union[int, bool], verbose: Union[int, bool],
xgb_model: Optional[Union[Booster, XGBModel]], xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection], feature_weights: Optional[_DaskCollection],
callbacks: Optional[Sequence[TrainingCallback]],
) -> _DaskCollection: ) -> _DaskCollection:
params = self.get_xgb_params() params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices( dtrain, evals = await _async_wrap_evaluation_matrices(
@ -1809,9 +1806,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
obj: Optional[Callable] = _objective_decorator(self.objective) obj: Optional[Callable] = _objective_decorator(self.objective)
else: else:
obj = None obj = None
model, metric, params, early_stopping_rounds, callbacks = self._configure_fit( model, metric, params = self._configure_fit(xgb_model, params)
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
results = await self.client.sync( results = await self.client.sync(
_train_async, _train_async,
asynchronous=True, asynchronous=True,
@ -1826,8 +1821,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
feval=None, feval=None,
custom_metric=metric, custom_metric=metric,
verbose_eval=verbose, verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=self.early_stopping_rounds,
callbacks=callbacks, callbacks=self.callbacks,
xgb_model=model, xgb_model=model,
) )
self._Booster = results["booster"] self._Booster = results["booster"]
@ -1844,14 +1839,11 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None, eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Union[int, bool] = True, verbose: Union[int, bool] = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None, xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None, sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None, base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "DaskXGBRegressor": ) -> "DaskXGBRegressor":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -1871,14 +1863,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
sample_weight: Optional[_DaskCollection], sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection], base_margin: Optional[_DaskCollection],
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]], eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
eval_metric: Optional[Union[str, Sequence[str], Metric]],
sample_weight_eval_set: Optional[Sequence[_DaskCollection]], sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
base_margin_eval_set: Optional[Sequence[_DaskCollection]], base_margin_eval_set: Optional[Sequence[_DaskCollection]],
early_stopping_rounds: Optional[int],
verbose: Union[int, bool], verbose: Union[int, bool],
xgb_model: Optional[Union[Booster, XGBModel]], xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection], feature_weights: Optional[_DaskCollection],
callbacks: Optional[Sequence[TrainingCallback]],
) -> "DaskXGBClassifier": ) -> "DaskXGBClassifier":
params = self.get_xgb_params() params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices( dtrain, evals = await _async_wrap_evaluation_matrices(
@ -1924,9 +1913,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
obj: Optional[Callable] = _objective_decorator(self.objective) obj: Optional[Callable] = _objective_decorator(self.objective)
else: else:
obj = None obj = None
model, metric, params, early_stopping_rounds, callbacks = self._configure_fit( model, metric, params = self._configure_fit(xgb_model, params)
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
results = await self.client.sync( results = await self.client.sync(
_train_async, _train_async,
asynchronous=True, asynchronous=True,
@ -1941,8 +1928,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
feval=None, feval=None,
custom_metric=metric, custom_metric=metric,
verbose_eval=verbose, verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=self.early_stopping_rounds,
callbacks=callbacks, callbacks=self.callbacks,
xgb_model=model, xgb_model=model,
) )
self._Booster = results["booster"] self._Booster = results["booster"]
@ -1960,14 +1947,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None, eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Union[int, bool] = True, verbose: Union[int, bool] = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None, xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None, sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None, base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "DaskXGBClassifier": ) -> "DaskXGBClassifier":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -2063,7 +2047,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any): def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
if callable(objective): if callable(objective):
raise ValueError("Custom objective function not supported by XGBRanker.") raise ValueError("Custom objective function not supported by XGBRanker.")
super().__init__(objective=objective, kwargs=kwargs) super().__init__(objective=objective, **kwargs)
async def _fit_async( async def _fit_async(
self, self,
@ -2078,12 +2062,9 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
base_margin_eval_set: Optional[Sequence[_DaskCollection]], base_margin_eval_set: Optional[Sequence[_DaskCollection]],
eval_group: Optional[Sequence[_DaskCollection]], eval_group: Optional[Sequence[_DaskCollection]],
eval_qid: Optional[Sequence[_DaskCollection]], eval_qid: Optional[Sequence[_DaskCollection]],
eval_metric: Optional[Union[str, Sequence[str], Metric]],
early_stopping_rounds: Optional[int],
verbose: Union[int, bool], verbose: Union[int, bool],
xgb_model: Optional[Union[XGBModel, Booster]], xgb_model: Optional[Union[XGBModel, Booster]],
feature_weights: Optional[_DaskCollection], feature_weights: Optional[_DaskCollection],
callbacks: Optional[Sequence[TrainingCallback]],
) -> "DaskXGBRanker": ) -> "DaskXGBRanker":
msg = "Use `qid` instead of `group` on dask interface." msg = "Use `qid` instead of `group` on dask interface."
if not (group is None and eval_group is None): if not (group is None and eval_group is None):
@ -2111,14 +2092,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
enable_categorical=self.enable_categorical, enable_categorical=self.enable_categorical,
feature_types=self.feature_types, feature_types=self.feature_types,
) )
if eval_metric is not None: model, metric, params = self._configure_fit(xgb_model, params)
if callable(eval_metric):
raise ValueError(
"Custom evaluation metric is not yet supported for XGBRanker."
)
model, metric, params, early_stopping_rounds, callbacks = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
results = await self.client.sync( results = await self.client.sync(
_train_async, _train_async,
asynchronous=True, asynchronous=True,
@ -2133,8 +2107,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
feval=None, feval=None,
custom_metric=metric, custom_metric=metric,
verbose_eval=verbose, verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=self.early_stopping_rounds,
callbacks=callbacks, callbacks=self.callbacks,
xgb_model=model, xgb_model=model,
) )
self._Booster = results["booster"] self._Booster = results["booster"]
@ -2155,14 +2129,11 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None, eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_group: Optional[Sequence[_DaskCollection]] = None, eval_group: Optional[Sequence[_DaskCollection]] = None,
eval_qid: Optional[Sequence[_DaskCollection]] = None, eval_qid: Optional[Sequence[_DaskCollection]] = None,
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Union[int, bool] = False, verbose: Union[int, bool] = False,
xgb_model: Optional[Union[XGBModel, Booster]] = None, xgb_model: Optional[Union[XGBModel, Booster]] = None,
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None, sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None, base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "DaskXGBRanker": ) -> "DaskXGBRanker":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -2221,18 +2192,15 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None, eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Union[int, bool] = True, verbose: Union[int, bool] = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None, xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None, sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None, base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "DaskXGBRFRegressor": ) -> "DaskXGBRFRegressor":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks) _check_rf_callback(self.early_stopping_rounds, self.callbacks)
super().fit(**args) super().fit(**args)
return self return self
@ -2285,17 +2253,14 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None, eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Union[int, bool] = True, verbose: Union[int, bool] = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None, xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None, sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None, base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "DaskXGBRFClassifier": ) -> "DaskXGBRFClassifier":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks) _check_rf_callback(self.early_stopping_rounds, self.callbacks)
super().fit(**args) super().fit(**args)
return self return self
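
The Dask estimators follow the same rule after this change: `eval_metric`, `early_stopping_rounds`, and `callbacks` are read from the estimator itself rather than from `fit()`. A minimal sketch against a local cluster (the cluster setup, data, and parameter values are illustrative):

    from dask import array as da
    from distributed import Client, LocalCluster

    from xgboost import dask as dxgb

    if __name__ == "__main__":
        with Client(LocalCluster(n_workers=2)) as client:
            X = da.random.random((10_000, 10), chunks=(1_000, 10))
            y = da.random.random(10_000, chunks=1_000)

            reg = dxgb.DaskXGBRegressor(
                n_estimators=100,
                eval_metric="rmse",
                early_stopping_rounds=5,  # constructor argument, no longer a fit() argument
            )
            reg.client = client
            reg.fit(X, y, eval_set=[(X, y)])
            print(reg.get_booster().num_boosted_rounds())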


@ -349,12 +349,6 @@ __model_doc = f"""
See :doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more See :doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more
information. information.
.. note::
This parameter replaces `eval_metric` in :py:meth:`fit` method. The old
one receives un-transformed prediction regardless of whether custom
objective is being used.
.. code-block:: python .. code-block:: python
from sklearn.datasets import load_diabetes from sklearn.datasets import load_diabetes
@ -389,10 +383,6 @@ __model_doc = f"""
early stopping. If there's more than one metric in **eval_metric**, the last early stopping. If there's more than one metric in **eval_metric**, the last
metric will be used for early stopping. metric will be used for early stopping.
.. note::
This parameter replaces `early_stopping_rounds` in :py:meth:`fit` method.
callbacks : Optional[List[TrainingCallback]] callbacks : Optional[List[TrainingCallback]]
List of callback functions that are applied at end of each iteration. List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using It is possible to use predefined callbacks by using
@ -872,16 +862,11 @@ class XGBModel(XGBModelBase):
def _configure_fit( def _configure_fit(
self, self,
booster: Optional[Union[Booster, "XGBModel", str]], booster: Optional[Union[Booster, "XGBModel", str]],
eval_metric: Optional[Union[Callable, str, Sequence[str]]],
params: Dict[str, Any], params: Dict[str, Any],
early_stopping_rounds: Optional[int],
callbacks: Optional[Sequence[TrainingCallback]],
) -> Tuple[ ) -> Tuple[
Optional[Union[Booster, str, "XGBModel"]], Optional[Union[Booster, str, "XGBModel"]],
Optional[Metric], Optional[Metric],
Dict[str, Any], Dict[str, Any],
Optional[int],
Optional[Sequence[TrainingCallback]],
]: ]:
"""Configure parameters for :py:meth:`fit`.""" """Configure parameters for :py:meth:`fit`."""
if isinstance(booster, XGBModel): if isinstance(booster, XGBModel):
@ -903,49 +888,16 @@ class XGBModel(XGBModelBase):
"or `set_params` instead." "or `set_params` instead."
) )
# Configure evaluation metric.
if eval_metric is not None:
_deprecated("eval_metric")
if self.eval_metric is not None and eval_metric is not None:
_duplicated("eval_metric")
# - track where does the evaluation metric come from
if self.eval_metric is not None:
from_fit = False
eval_metric = self.eval_metric
else:
from_fit = True
# - configure callable evaluation metric # - configure callable evaluation metric
metric: Optional[Metric] = None metric: Optional[Metric] = None
if eval_metric is not None: if self.eval_metric is not None:
if callable(eval_metric) and from_fit: if callable(self.eval_metric):
# No need to wrap the evaluation function for old parameter.
metric = eval_metric
elif callable(eval_metric):
# Parameter from constructor or set_params
if self._get_type() == "ranker": if self._get_type() == "ranker":
metric = ltr_metric_decorator(eval_metric, self.n_jobs) metric = ltr_metric_decorator(self.eval_metric, self.n_jobs)
else: else:
metric = _metric_decorator(eval_metric) metric = _metric_decorator(self.eval_metric)
else: else:
params.update({"eval_metric": eval_metric}) params.update({"eval_metric": self.eval_metric})
# Configure early_stopping_rounds
if early_stopping_rounds is not None:
_deprecated("early_stopping_rounds")
if early_stopping_rounds is not None and self.early_stopping_rounds is not None:
_duplicated("early_stopping_rounds")
early_stopping_rounds = (
self.early_stopping_rounds
if self.early_stopping_rounds is not None
else early_stopping_rounds
)
# Configure callbacks
if callbacks is not None:
_deprecated("callbacks")
if callbacks is not None and self.callbacks is not None:
_duplicated("callbacks")
callbacks = self.callbacks if self.callbacks is not None else callbacks
tree_method = params.get("tree_method", None) tree_method = params.get("tree_method", None)
if self.enable_categorical and tree_method == "exact": if self.enable_categorical and tree_method == "exact":
@ -953,7 +905,7 @@ class XGBModel(XGBModelBase):
"Experimental support for categorical data is not implemented for" "Experimental support for categorical data is not implemented for"
" current tree method yet." " current tree method yet."
) )
return model, metric, params, early_stopping_rounds, callbacks return model, metric, params
def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix: def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
# Use `QuantileDMatrix` to save memory. # Use `QuantileDMatrix` to save memory.
@ -979,14 +931,11 @@ class XGBModel(XGBModelBase):
sample_weight: Optional[ArrayLike] = None, sample_weight: Optional[ArrayLike] = None,
base_margin: Optional[ArrayLike] = None, base_margin: Optional[ArrayLike] = None,
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = True, verbose: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None, xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None, feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBModel": ) -> "XGBModel":
# pylint: disable=invalid-name,attribute-defined-outside-init # pylint: disable=invalid-name,attribute-defined-outside-init
"""Fit gradient boosting model. """Fit gradient boosting model.
@ -1017,18 +966,6 @@ class XGBModel(XGBModelBase):
metrics will be computed. metrics will be computed.
Validation metrics will help us track the performance of the model. Validation metrics will help us track the performance of the model.
eval_metric : str, list of str, or callable, optional
.. deprecated:: 1.6.0
Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
early_stopping_rounds : int
.. deprecated:: 1.6.0
Use `early_stopping_rounds` in :py:meth:`__init__` or :py:meth:`set_params`
instead.
verbose : verbose :
If `verbose` is True and an evaluation set is used, the evaluation metric If `verbose` is True and an evaluation set is used, the evaluation metric
measured on the validation set is printed to stdout at each boosting stage. measured on the validation set is printed to stdout at each boosting stage.
@ -1049,10 +986,6 @@ class XGBModel(XGBModelBase):
selected when colsample is being used. All values must be greater than 0, selected when colsample is being used. All values must be greater than 0,
otherwise a `ValueError` is thrown. otherwise a `ValueError` is thrown.
callbacks :
.. deprecated:: 1.6.0
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
""" """
with config_context(verbosity=self.verbosity): with config_context(verbosity=self.verbosity):
evals_result: TrainingCallback.EvalsLog = {} evals_result: TrainingCallback.EvalsLog = {}
@ -1082,27 +1015,19 @@ class XGBModel(XGBModelBase):
else: else:
obj = None obj = None
( model, metric, params = self._configure_fit(xgb_model, params)
model,
metric,
params,
early_stopping_rounds,
callbacks,
) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
self._Booster = train( self._Booster = train(
params, params,
train_dmatrix, train_dmatrix,
self.get_num_boosting_rounds(), self.get_num_boosting_rounds(),
evals=evals, evals=evals,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=self.early_stopping_rounds,
evals_result=evals_result, evals_result=evals_result,
obj=obj, obj=obj,
custom_metric=metric, custom_metric=metric,
verbose_eval=verbose, verbose_eval=verbose,
xgb_model=model, xgb_model=model,
callbacks=callbacks, callbacks=self.callbacks,
) )
self._set_evaluation_result(evals_result) self._set_evaluation_result(evals_result)
@ -1437,14 +1362,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
sample_weight: Optional[ArrayLike] = None, sample_weight: Optional[ArrayLike] = None,
base_margin: Optional[ArrayLike] = None, base_margin: Optional[ArrayLike] = None,
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = True, verbose: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None, xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None, feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBClassifier": ) -> "XGBClassifier":
# pylint: disable = attribute-defined-outside-init,too-many-statements # pylint: disable = attribute-defined-outside-init,too-many-statements
with config_context(verbosity=self.verbosity): with config_context(verbosity=self.verbosity):
@ -1492,15 +1414,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
params["objective"] = "multi:softprob" params["objective"] = "multi:softprob"
params["num_class"] = self.n_classes_ params["num_class"] = self.n_classes_
( model, metric, params = self._configure_fit(xgb_model, params)
model,
metric,
params,
early_stopping_rounds,
callbacks,
) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
train_dmatrix, evals = _wrap_evaluation_matrices( train_dmatrix, evals = _wrap_evaluation_matrices(
missing=self.missing, missing=self.missing,
X=X, X=X,
@ -1525,13 +1439,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
train_dmatrix, train_dmatrix,
self.get_num_boosting_rounds(), self.get_num_boosting_rounds(),
evals=evals, evals=evals,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=self.early_stopping_rounds,
evals_result=evals_result, evals_result=evals_result,
obj=obj, obj=obj,
custom_metric=metric, custom_metric=metric,
verbose_eval=verbose, verbose_eval=verbose,
xgb_model=model, xgb_model=model,
callbacks=callbacks, callbacks=self.callbacks,
) )
if not callable(self.objective): if not callable(self.objective):
@ -1693,17 +1607,14 @@ class XGBRFClassifier(XGBClassifier):
sample_weight: Optional[ArrayLike] = None, sample_weight: Optional[ArrayLike] = None,
base_margin: Optional[ArrayLike] = None, base_margin: Optional[ArrayLike] = None,
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = True, verbose: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None, xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None, feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBRFClassifier": ) -> "XGBRFClassifier":
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks) _check_rf_callback(self.early_stopping_rounds, self.callbacks)
super().fit(**args) super().fit(**args)
return self return self
@ -1768,17 +1679,14 @@ class XGBRFRegressor(XGBRegressor):
sample_weight: Optional[ArrayLike] = None, sample_weight: Optional[ArrayLike] = None,
base_margin: Optional[ArrayLike] = None, base_margin: Optional[ArrayLike] = None,
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = True, verbose: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None, xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None, feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBRFRegressor": ) -> "XGBRFRegressor":
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks) _check_rf_callback(self.early_stopping_rounds, self.callbacks)
super().fit(**args) super().fit(**args)
return self return self
@ -1883,14 +1791,11 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_group: Optional[Sequence[ArrayLike]] = None, eval_group: Optional[Sequence[ArrayLike]] = None,
eval_qid: Optional[Sequence[ArrayLike]] = None, eval_qid: Optional[Sequence[ArrayLike]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = False, verbose: Optional[Union[bool, int]] = False,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None, xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None, feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBRanker": ) -> "XGBRanker":
# pylint: disable = attribute-defined-outside-init,arguments-differ # pylint: disable = attribute-defined-outside-init,arguments-differ
"""Fit gradient boosting ranker """Fit gradient boosting ranker
@ -1960,15 +1865,6 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
pair in **eval_set**. The special column convention in `X` applies to pair in **eval_set**. The special column convention in `X` applies to
validation datasets as well. validation datasets as well.
eval_metric : str, list of str, optional
.. deprecated:: 1.6.0
use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
early_stopping_rounds : int
.. deprecated:: 1.6.0
use `early_stopping_rounds` in :py:meth:`__init__` or
:py:meth:`set_params` instead.
verbose : verbose :
If `verbose` is True and an evaluation set is used, the evaluation metric If `verbose` is True and an evaluation set is used, the evaluation metric
measured on the validation set is printed to stdout at each boosting stage. measured on the validation set is printed to stdout at each boosting stage.
@ -1996,10 +1892,6 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
selected when colsample is being used. All values must be greater than 0, selected when colsample is being used. All values must be greater than 0,
otherwise a `ValueError` is thrown. otherwise a `ValueError` is thrown.
callbacks :
.. deprecated:: 1.6.0
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
""" """
with config_context(verbosity=self.verbosity): with config_context(verbosity=self.verbosity):
train_dmatrix, evals = _wrap_evaluation_matrices( train_dmatrix, evals = _wrap_evaluation_matrices(
@ -2024,27 +1916,19 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
evals_result: TrainingCallback.EvalsLog = {} evals_result: TrainingCallback.EvalsLog = {}
params = self.get_xgb_params() params = self.get_xgb_params()
( model, metric, params = self._configure_fit(xgb_model, params)
model,
metric,
params,
early_stopping_rounds,
callbacks,
) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
self._Booster = train( self._Booster = train(
params, params,
train_dmatrix, train_dmatrix,
num_boost_round=self.get_num_boosting_rounds(), num_boost_round=self.get_num_boosting_rounds(),
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=self.early_stopping_rounds,
evals=evals, evals=evals,
evals_result=evals_result, evals_result=evals_result,
custom_metric=metric, custom_metric=metric,
verbose_eval=verbose, verbose_eval=verbose,
xgb_model=model, xgb_model=model,
callbacks=callbacks, callbacks=self.callbacks,
) )
self.objective = params["objective"] self.objective = params["objective"]
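
The ranker's `fit()` keeps `qid`/`eval_qid` and the group-column convention described in its docstring; only the deprecated keyword arguments are gone, and the evaluation metric now comes from the constructor. A small sketch with synthetic query groups (data and parameter values are illustrative):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.random((120, 5))
    y = rng.integers(0, 4, size=120)    # graded relevance labels
    qid = np.repeat(np.arange(12), 10)  # 12 queries, 10 documents each, sorted by query

    ranker = xgb.XGBRanker(
        objective="rank:ndcg",
        eval_metric="ndcg@5",
        n_estimators=20,
    )
    ranker.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid])
    print(ranker.predict(X[:10]))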

View File

@ -18,10 +18,12 @@ class LintersPaths:
"python-package/", "python-package/",
# tests # tests
"tests/python/test_config.py", "tests/python/test_config.py",
"tests/python/test_callback.py",
"tests/python/test_data_iterator.py", "tests/python/test_data_iterator.py",
"tests/python/test_dmatrix.py", "tests/python/test_dmatrix.py",
"tests/python/test_dt.py", "tests/python/test_dt.py",
"tests/python/test_demos.py", "tests/python/test_demos.py",
"tests/python/test_eval_metrics.py",
"tests/python/test_multi_target.py", "tests/python/test_multi_target.py",
"tests/python/test_predict.py", "tests/python/test_predict.py",
"tests/python/test_quantile_dmatrix.py", "tests/python/test_quantile_dmatrix.py",
@ -39,12 +41,15 @@ class LintersPaths:
"demo/dask/", "demo/dask/",
"demo/rmm_plugin", "demo/rmm_plugin",
"demo/json-model/json_parser.py", "demo/json-model/json_parser.py",
"demo/guide-python/continuation.py",
"demo/guide-python/cat_in_the_dat.py", "demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/callbacks.py", "demo/guide-python/callbacks.py",
"demo/guide-python/categorical.py", "demo/guide-python/categorical.py",
"demo/guide-python/cat_pipeline.py", "demo/guide-python/cat_pipeline.py",
"demo/guide-python/feature_weights.py", "demo/guide-python/feature_weights.py",
"demo/guide-python/sklearn_parallel.py", "demo/guide-python/sklearn_parallel.py",
"demo/guide-python/sklearn_examples.py",
"demo/guide-python/sklearn_evals_result.py",
"demo/guide-python/spark_estimator_examples.py", "demo/guide-python/spark_estimator_examples.py",
"demo/guide-python/external_memory.py", "demo/guide-python/external_memory.py",
"demo/guide-python/individual_trees.py", "demo/guide-python/individual_trees.py",
@ -93,6 +98,7 @@ class LintersPaths:
# demo # demo
"demo/json-model/json_parser.py", "demo/json-model/json_parser.py",
"demo/guide-python/external_memory.py", "demo/guide-python/external_memory.py",
"demo/guide-python/continuation.py",
"demo/guide-python/callbacks.py", "demo/guide-python/callbacks.py",
"demo/guide-python/cat_in_the_dat.py", "demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/categorical.py", "demo/guide-python/categorical.py",


@ -16,13 +16,14 @@ class TestCallbacks:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
cls.X = X cls.X = X
cls.y = y cls.y = y
split = int(X.shape[0]*0.8) split = int(X.shape[0] * 0.8)
cls.X_train = X[: split, ...] cls.X_train = X[:split, ...]
cls.y_train = y[: split, ...] cls.y_train = y[:split, ...]
cls.X_valid = X[split:, ...] cls.X_valid = X[split:, ...]
cls.y_valid = y[split:, ...] cls.y_valid = y[split:, ...]
@ -31,31 +32,32 @@ class TestCallbacks:
D_train: xgb.DMatrix, D_train: xgb.DMatrix,
D_valid: xgb.DMatrix, D_valid: xgb.DMatrix,
rounds: int, rounds: int,
verbose_eval: Union[bool, int] verbose_eval: Union[bool, int],
): ):
def check_output(output: str) -> None: def check_output(output: str) -> None:
if int(verbose_eval) == 1: if int(verbose_eval) == 1:
# Should print each iteration info # Should print each iteration info
assert len(output.split('\n')) == rounds assert len(output.split("\n")) == rounds
elif int(verbose_eval) > rounds: elif int(verbose_eval) > rounds:
# Should print first and latest iteration info # Should print first and latest iteration info
assert len(output.split('\n')) == 2 assert len(output.split("\n")) == 2
else: else:
# Should print info by each period additionally to first and latest # iteration
# iteration # iteration
num_periods = rounds // int(verbose_eval) num_periods = rounds // int(verbose_eval)
# Extra information is required for latest iteration # Extra information is required for latest iteration
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1) is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
assert len(output.split('\n')) == ( assert len(output.split("\n")) == (
1 + num_periods + int(is_extra_info_required) 1 + num_periods + int(is_extra_info_required)
) )
evals_result: xgb.callback.TrainingCallback.EvalsLog = {} evals_result: xgb.callback.TrainingCallback.EvalsLog = {}
params = {'objective': 'binary:logistic', 'eval_metric': 'error'} params = {"objective": "binary:logistic", "eval_metric": "error"}
with tm.captured_output() as (out, err): with tm.captured_output() as (out, err):
xgb.train( xgb.train(
params, D_train, params,
evals=[(D_train, 'Train'), (D_valid, 'Valid')], D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
num_boost_round=rounds, num_boost_round=rounds,
evals_result=evals_result, evals_result=evals_result,
verbose_eval=verbose_eval, verbose_eval=verbose_eval,
@ -73,14 +75,16 @@ class TestCallbacks:
D_valid = xgb.DMatrix(self.X_valid, self.y_valid) D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
evals_result = {} evals_result = {}
rounds = 10 rounds = 10
xgb.train({'objective': 'binary:logistic', xgb.train(
'eval_metric': 'error'}, D_train, {"objective": "binary:logistic", "eval_metric": "error"},
evals=[(D_train, 'Train'), (D_valid, 'Valid')], D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
num_boost_round=rounds, num_boost_round=rounds,
evals_result=evals_result, evals_result=evals_result,
verbose_eval=True) verbose_eval=True,
assert len(evals_result['Train']['error']) == rounds )
assert len(evals_result['Valid']['error']) == rounds assert len(evals_result["Train"]["error"]) == rounds
assert len(evals_result["Valid"]["error"]) == rounds
self.run_evaluation_monitor(D_train, D_valid, rounds, True) self.run_evaluation_monitor(D_train, D_valid, rounds, True)
self.run_evaluation_monitor(D_train, D_valid, rounds, 2) self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
@ -93,72 +97,83 @@ class TestCallbacks:
evals_result = {} evals_result = {}
rounds = 30 rounds = 30
early_stopping_rounds = 5 early_stopping_rounds = 5
booster = xgb.train({'objective': 'binary:logistic', booster = xgb.train(
'eval_metric': 'error'}, D_train, {"objective": "binary:logistic", "eval_metric": "error"},
evals=[(D_train, 'Train'), (D_valid, 'Valid')], D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
num_boost_round=rounds, num_boost_round=rounds,
evals_result=evals_result, evals_result=evals_result,
verbose_eval=True, verbose_eval=True,
early_stopping_rounds=early_stopping_rounds) early_stopping_rounds=early_stopping_rounds,
dump = booster.get_dump(dump_format='json') )
dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_early_stopping_custom_eval(self): def test_early_stopping_custom_eval(self):
D_train = xgb.DMatrix(self.X_train, self.y_train) D_train = xgb.DMatrix(self.X_train, self.y_train)
D_valid = xgb.DMatrix(self.X_valid, self.y_valid) D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
early_stopping_rounds = 5 early_stopping_rounds = 5
booster = xgb.train({'objective': 'binary:logistic', booster = xgb.train(
'eval_metric': 'error', {
'tree_method': 'hist'}, D_train, "objective": "binary:logistic",
evals=[(D_train, 'Train'), (D_valid, 'Valid')], "eval_metric": "error",
"tree_method": "hist",
},
D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
feval=tm.eval_error_metric, feval=tm.eval_error_metric,
num_boost_round=1000, num_boost_round=1000,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
verbose_eval=False) verbose_eval=False,
dump = booster.get_dump(dump_format='json') )
dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_early_stopping_customize(self): def test_early_stopping_customize(self):
D_train = xgb.DMatrix(self.X_train, self.y_train) D_train = xgb.DMatrix(self.X_train, self.y_train)
D_valid = xgb.DMatrix(self.X_valid, self.y_valid) D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
early_stopping_rounds = 5 early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, early_stop = xgb.callback.EarlyStopping(
metric_name='CustomErr', rounds=early_stopping_rounds, metric_name="CustomErr", data_name="Train"
data_name='Train') )
# Specify which dataset and which metric should be used for early stopping. # Specify which dataset and which metric should be used for early stopping.
booster = xgb.train( booster = xgb.train(
{'objective': 'binary:logistic', {
'eval_metric': ['error', 'rmse'], "objective": "binary:logistic",
'tree_method': 'hist'}, D_train, "eval_metric": ["error", "rmse"],
evals=[(D_train, 'Train'), (D_valid, 'Valid')], "tree_method": "hist",
},
D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
feval=tm.eval_error_metric, feval=tm.eval_error_metric,
num_boost_round=1000, num_boost_round=1000,
callbacks=[early_stop], callbacks=[early_stop],
verbose_eval=False) verbose_eval=False,
dump = booster.get_dump(dump_format='json') )
dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump) assert len(early_stop.stopping_history["Train"]["CustomErr"]) == len(dump)
rounds = 100 rounds = 100
early_stop = xgb.callback.EarlyStopping( early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds, rounds=early_stopping_rounds,
metric_name='CustomErr', metric_name="CustomErr",
data_name='Train', data_name="Train",
min_delta=100, min_delta=100,
save_best=True, save_best=True,
) )
booster = xgb.train( booster = xgb.train(
{ {
'objective': 'binary:logistic', "objective": "binary:logistic",
'eval_metric': ['error', 'rmse'], "eval_metric": ["error", "rmse"],
'tree_method': 'hist' "tree_method": "hist",
}, },
D_train, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')], evals=[(D_train, "Train"), (D_valid, "Valid")],
feval=tm.eval_error_metric, feval=tm.eval_error_metric,
num_boost_round=rounds, num_boost_round=rounds,
callbacks=[early_stop], callbacks=[early_stop],
verbose_eval=False verbose_eval=False,
) )
# No iteration can be made with min_delta == 100 # No iteration can be made with min_delta == 100
assert booster.best_iteration == 0 assert booster.best_iteration == 0
@ -166,18 +181,20 @@ class TestCallbacks:
def test_early_stopping_skl(self): def test_early_stopping_skl(self):
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
early_stopping_rounds = 5 early_stopping_rounds = 5
cls = xgb.XGBClassifier( cls = xgb.XGBClassifier(
early_stopping_rounds=early_stopping_rounds, eval_metric='error' early_stopping_rounds=early_stopping_rounds, eval_metric="error"
) )
cls.fit(X, y, eval_set=[(X, y)]) cls.fit(X, y, eval_set=[(X, y)])
booster = cls.get_booster() booster = cls.get_booster()
dump = booster.get_dump(dump_format='json') dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_early_stopping_custom_eval_skl(self): def test_early_stopping_custom_eval_skl(self):
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
early_stopping_rounds = 5 early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds) early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds)
@ -186,11 +203,12 @@ class TestCallbacks:
) )
cls.fit(X, y, eval_set=[(X, y)]) cls.fit(X, y, eval_set=[(X, y)])
booster = cls.get_booster() booster = cls.get_booster()
dump = booster.get_dump(dump_format='json') dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_early_stopping_save_best_model(self): def test_early_stopping_save_best_model(self):
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
n_estimators = 100 n_estimators = 100
early_stopping_rounds = 5 early_stopping_rounds = 5
@ -200,11 +218,11 @@ class TestCallbacks:
cls = xgb.XGBClassifier( cls = xgb.XGBClassifier(
n_estimators=n_estimators, n_estimators=n_estimators,
eval_metric=tm.eval_error_metric_skl, eval_metric=tm.eval_error_metric_skl,
callbacks=[early_stop] callbacks=[early_stop],
) )
cls.fit(X, y, eval_set=[(X, y)]) cls.fit(X, y, eval_set=[(X, y)])
booster = cls.get_booster() booster = cls.get_booster()
dump = booster.get_dump(dump_format='json') dump = booster.get_dump(dump_format="json")
assert len(dump) == booster.best_iteration + 1 assert len(dump) == booster.best_iteration + 1
early_stop = xgb.callback.EarlyStopping( early_stop = xgb.callback.EarlyStopping(
@ -220,8 +238,9 @@ class TestCallbacks:
cls.fit(X, y, eval_set=[(X, y)]) cls.fit(X, y, eval_set=[(X, y)])
# No error # No error
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, early_stop = xgb.callback.EarlyStopping(
save_best=False) rounds=early_stopping_rounds, save_best=False
)
xgb.XGBClassifier( xgb.XGBClassifier(
booster="gblinear", booster="gblinear",
n_estimators=10, n_estimators=10,
@ -231,14 +250,17 @@ class TestCallbacks:
def test_early_stopping_continuation(self): def test_early_stopping_continuation(self):
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
cls = xgb.XGBClassifier(eval_metric=tm.eval_error_metric_skl)
early_stopping_rounds = 5 early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping( early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds, save_best=True rounds=early_stopping_rounds, save_best=True
) )
with pytest.warns(UserWarning): cls = xgb.XGBClassifier(
cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop]) eval_metric=tm.eval_error_metric_skl, callbacks=[early_stop]
)
cls.fit(X, y, eval_set=[(X, y)])
booster = cls.get_booster() booster = cls.get_booster()
assert booster.num_boosted_rounds() == booster.best_iteration + 1 assert booster.num_boosted_rounds() == booster.best_iteration + 1
@ -256,21 +278,10 @@ class TestCallbacks:
) )
cls.fit(X, y, eval_set=[(X, y)]) cls.fit(X, y, eval_set=[(X, y)])
booster = cls.get_booster() booster = cls.get_booster()
assert booster.num_boosted_rounds() == \ assert (
booster.best_iteration + early_stopping_rounds + 1 booster.num_boosted_rounds()
== booster.best_iteration + early_stopping_rounds + 1
def test_deprecated(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds, save_best=True
) )
clf = xgb.XGBClassifier(
eval_metric=tm.eval_error_metric_skl, callbacks=[early_stop]
)
with pytest.raises(ValueError, match=r".*set_params.*"):
clf.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
def run_eta_decay(self, tree_method): def run_eta_decay(self, tree_method):
"""Test learning rate scheduler, used by both CPU and GPU tests.""" """Test learning rate scheduler, used by both CPU and GPU tests."""
@ -343,7 +354,7 @@ class TestCallbacks:
callbacks=[scheduler([0, 0, 0, 0])], callbacks=[scheduler([0, 0, 0, 0])],
evals_result=evals_result, evals_result=evals_result,
) )
eval_errors_2 = list(map(float, evals_result['eval']['error'])) eval_errors_2 = list(map(float, evals_result["eval"]["error"]))
assert isinstance(bst, xgb.core.Booster) assert isinstance(bst, xgb.core.Booster)
# validation error should not decrease, if eta/learning_rate = 0 # validation error should not decrease, if eta/learning_rate = 0
assert eval_errors_2[0] == eval_errors_2[-1] assert eval_errors_2[0] == eval_errors_2[-1]
@ -361,7 +372,7 @@ class TestCallbacks:
callbacks=[scheduler(eta_decay)], callbacks=[scheduler(eta_decay)],
evals_result=evals_result, evals_result=evals_result,
) )
eval_errors_3 = list(map(float, evals_result['eval']['error'])) eval_errors_3 = list(map(float, evals_result["eval"]["error"]))
assert isinstance(bst, xgb.core.Booster) assert isinstance(bst, xgb.core.Booster)
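
The scheduler exercised above is presumably xgboost's learning-rate callback; a small self-contained sketch of how such a schedule plugs into xgb.train (synthetic data and a made-up decay list, not the test's values):

import xgboost as xgb
from sklearn.datasets import make_regression

# Sketch only: synthetic regression data and an arbitrary decay schedule.
X, y = make_regression(n_samples=128, n_features=8, random_state=0)
dtrain = xgb.DMatrix(X, label=y)

evals_result: dict = {}
bst = xgb.train(
    {"objective": "reg:squarederror", "tree_method": "hist"},
    dtrain,
    num_boost_round=4,
    evals=[(dtrain, "eval")],
    # One learning rate per boosting round; a callable epoch -> eta also works.
    callbacks=[xgb.callback.LearningRateScheduler([0.3, 0.2, 0.1, 0.05])],
    evals_result=evals_result,
    verbose_eval=False,
)
print(evals_result["eval"]["rmse"])
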
@ -15,23 +15,23 @@ class TestEarlyStopping:
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
digits = load_digits(n_class=2) digits = load_digits(n_class=2)
X = digits['data'] X = digits["data"]
y = digits['target'] y = digits["target"]
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf1 = xgb.XGBClassifier(learning_rate=0.1) clf1 = xgb.XGBClassifier(
clf1.fit(X_train, y_train, early_stopping_rounds=5, eval_metric="auc", learning_rate=0.1, early_stopping_rounds=5, eval_metric="auc"
eval_set=[(X_test, y_test)]) )
clf2 = xgb.XGBClassifier(learning_rate=0.1) clf1.fit(X_train, y_train, eval_set=[(X_test, y_test)])
clf2.fit(X_train, y_train, early_stopping_rounds=4, eval_metric="auc", clf2 = xgb.XGBClassifier(
eval_set=[(X_test, y_test)]) learning_rate=0.1, early_stopping_rounds=4, eval_metric="auc"
)
clf2.fit(X_train, y_train, eval_set=[(X_test, y_test)])
# should be the same # should be the same
assert clf1.best_score == clf2.best_score assert clf1.best_score == clf2.best_score
assert clf1.best_score != 1 assert clf1.best_score != 1
# check overfit # check overfit
clf3 = xgb.XGBClassifier( clf3 = xgb.XGBClassifier(
learning_rate=0.1, learning_rate=0.1, eval_metric="auc", early_stopping_rounds=10
eval_metric="auc",
early_stopping_rounds=10
) )
clf3.fit(X_train, y_train, eval_set=[(X_test, y_test)]) clf3.fit(X_train, y_train, eval_set=[(X_test, y_test)])
base_score = get_basescore(clf3) base_score = get_basescore(clf3)
@ -39,9 +39,9 @@ class TestEarlyStopping:
clf3 = xgb.XGBClassifier( clf3 = xgb.XGBClassifier(
learning_rate=0.1, learning_rate=0.1,
base_score=.5, base_score=0.5,
eval_metric="auc", eval_metric="auc",
early_stopping_rounds=10 early_stopping_rounds=10,
) )
clf3.fit(X_train, y_train, eval_set=[(X_test, y_test)]) clf3.fit(X_train, y_train, eval_set=[(X_test, y_test)])
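
For reference, the constructor-held early_stopping_rounds/eval_metric pair used here, in standalone form (same toy digits data, arbitrary settings):

import xgboost as xgb
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

X, y = load_digits(n_class=2, return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Both the metric and the stopping rule are estimator parameters now.
clf = xgb.XGBClassifier(learning_rate=0.1, eval_metric="auc", early_stopping_rounds=5)
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
print(clf.best_score, clf.best_iteration)
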
@ -9,37 +9,41 @@ rng = np.random.RandomState(1337)
class TestEvalMetrics: class TestEvalMetrics:
xgb_params_01 = {'nthread': 1, 'eval_metric': 'error'} xgb_params_01 = {"nthread": 1, "eval_metric": "error"}
xgb_params_02 = {'nthread': 1, 'eval_metric': ['error']} xgb_params_02 = {"nthread": 1, "eval_metric": ["error"]}
xgb_params_03 = {'nthread': 1, 'eval_metric': ['rmse', 'error']} xgb_params_03 = {"nthread": 1, "eval_metric": ["rmse", "error"]}
xgb_params_04 = {'nthread': 1, 'eval_metric': ['error', 'rmse']} xgb_params_04 = {"nthread": 1, "eval_metric": ["error", "rmse"]}
def evalerror_01(self, preds, dtrain): def evalerror_01(self, preds, dtrain):
labels = dtrain.get_label() labels = dtrain.get_label()
return 'error', float(sum(labels != (preds > 0.0))) / len(labels) return "error", float(sum(labels != (preds > 0.0))) / len(labels)
def evalerror_02(self, preds, dtrain): def evalerror_02(self, preds, dtrain):
labels = dtrain.get_label() labels = dtrain.get_label()
return [('error', float(sum(labels != (preds > 0.0))) / len(labels))] return [("error", float(sum(labels != (preds > 0.0))) / len(labels))]
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
def evalerror_03(self, preds, dtrain): def evalerror_03(self, preds, dtrain):
from sklearn.metrics import mean_squared_error from sklearn.metrics import mean_squared_error
labels = dtrain.get_label() labels = dtrain.get_label()
return [('rmse', mean_squared_error(labels, preds)), return [
('error', float(sum(labels != (preds > 0.0))) / len(labels))] ("rmse", mean_squared_error(labels, preds)),
("error", float(sum(labels != (preds > 0.0))) / len(labels)),
]
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
def evalerror_04(self, preds, dtrain): def evalerror_04(self, preds, dtrain):
from sklearn.metrics import mean_squared_error from sklearn.metrics import mean_squared_error
labels = dtrain.get_label() labels = dtrain.get_label()
return [('error', float(sum(labels != (preds > 0.0))) / len(labels)), return [
('rmse', mean_squared_error(labels, preds))] ("error", float(sum(labels != (preds > 0.0))) / len(labels)),
("rmse", mean_squared_error(labels, preds)),
]
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
def test_eval_metrics(self): def test_eval_metrics(self):
@ -50,15 +54,15 @@ class TestEvalMetrics:
from sklearn.datasets import load_digits from sklearn.datasets import load_digits
digits = load_digits(n_class=2) digits = load_digits(n_class=2)
X = digits['data'] X = digits["data"]
y = digits['target'] y = digits["target"]
Xt, Xv, yt, yv = train_test_split(X, y, test_size=0.2, random_state=0) Xt, Xv, yt, yv = train_test_split(X, y, test_size=0.2, random_state=0)
dtrain = xgb.DMatrix(Xt, label=yt) dtrain = xgb.DMatrix(Xt, label=yt)
dvalid = xgb.DMatrix(Xv, label=yv) dvalid = xgb.DMatrix(Xv, label=yv)
watchlist = [(dtrain, 'train'), (dvalid, 'val')] watchlist = [(dtrain, "train"), (dvalid, "val")]
gbdt_01 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10) gbdt_01 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10)
gbdt_02 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=10) gbdt_02 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=10)
@ -66,26 +70,54 @@ class TestEvalMetrics:
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0] assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0] assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
gbdt_01 = xgb.train(self.xgb_params_01, dtrain, 10, watchlist, gbdt_01 = xgb.train(
early_stopping_rounds=2) self.xgb_params_01, dtrain, 10, watchlist, early_stopping_rounds=2
gbdt_02 = xgb.train(self.xgb_params_02, dtrain, 10, watchlist, )
early_stopping_rounds=2) gbdt_02 = xgb.train(
gbdt_03 = xgb.train(self.xgb_params_03, dtrain, 10, watchlist, self.xgb_params_02, dtrain, 10, watchlist, early_stopping_rounds=2
early_stopping_rounds=2) )
gbdt_04 = xgb.train(self.xgb_params_04, dtrain, 10, watchlist, gbdt_03 = xgb.train(
early_stopping_rounds=2) self.xgb_params_03, dtrain, 10, watchlist, early_stopping_rounds=2
)
gbdt_04 = xgb.train(
self.xgb_params_04, dtrain, 10, watchlist, early_stopping_rounds=2
)
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0] assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0] assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0] assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0]
gbdt_01 = xgb.train(self.xgb_params_01, dtrain, 10, watchlist, gbdt_01 = xgb.train(
early_stopping_rounds=2, feval=self.evalerror_01) self.xgb_params_01,
gbdt_02 = xgb.train(self.xgb_params_02, dtrain, 10, watchlist, dtrain,
early_stopping_rounds=2, feval=self.evalerror_02) 10,
gbdt_03 = xgb.train(self.xgb_params_03, dtrain, 10, watchlist, watchlist,
early_stopping_rounds=2, feval=self.evalerror_03) early_stopping_rounds=2,
gbdt_04 = xgb.train(self.xgb_params_04, dtrain, 10, watchlist, feval=self.evalerror_01,
early_stopping_rounds=2, feval=self.evalerror_04) )
gbdt_02 = xgb.train(
self.xgb_params_02,
dtrain,
10,
watchlist,
early_stopping_rounds=2,
feval=self.evalerror_02,
)
gbdt_03 = xgb.train(
self.xgb_params_03,
dtrain,
10,
watchlist,
early_stopping_rounds=2,
feval=self.evalerror_03,
)
gbdt_04 = xgb.train(
self.xgb_params_04,
dtrain,
10,
watchlist,
early_stopping_rounds=2,
feval=self.evalerror_04,
)
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0] assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0] assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0] assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0]
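
feval here is the older custom-metric hook of xgb.train; on recent releases the same callable is passed as custom_metric. A standalone sketch under that assumption (toy data, arbitrary parameters):

import xgboost as xgb
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split


def evalerror(preds, dtrain):
    # Same shape of metric as in the tests above: (name, value).
    labels = dtrain.get_label()
    return "error", float(sum(labels != (preds > 0.0))) / len(labels)


X, y = load_digits(n_class=2, return_X_y=True)
Xt, Xv, yt, yv = train_test_split(X, y, test_size=0.2, random_state=0)
dtrain, dvalid = xgb.DMatrix(Xt, label=yt), xgb.DMatrix(Xv, label=yv)

bst = xgb.train(
    {"nthread": 1, "eval_metric": "error"},
    dtrain,
    num_boost_round=10,
    evals=[(dtrain, "train"), (dvalid, "val")],
    early_stopping_rounds=2,
    custom_metric=evalerror,  # assumed successor of the deprecated feval
    verbose_eval=False,
)
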
@ -93,6 +125,7 @@ class TestEvalMetrics:
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
def test_gamma_deviance(self): def test_gamma_deviance(self):
from sklearn.metrics import mean_gamma_deviance from sklearn.metrics import mean_gamma_deviance
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
n_samples = 100 n_samples = 100
n_features = 30 n_features = 30
@ -101,8 +134,13 @@ class TestEvalMetrics:
y = rng.randn(n_samples) y = rng.randn(n_samples)
y = y - y.min() * 100 y = y - y.min() * 100
reg = xgb.XGBRegressor(tree_method="hist", objective="reg:gamma", n_estimators=10) reg = xgb.XGBRegressor(
reg.fit(X, y, eval_metric="gamma-deviance") tree_method="hist",
objective="reg:gamma",
n_estimators=10,
eval_metric="gamma-deviance",
)
reg.fit(X, y)
booster = reg.get_booster() booster = reg.get_booster()
score = reg.predict(X) score = reg.predict(X)
@ -113,16 +151,26 @@ class TestEvalMetrics:
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
def test_gamma_lik(self) -> None: def test_gamma_lik(self) -> None:
import scipy.stats as stats import scipy.stats as stats
rng = np.random.default_rng(1994) rng = np.random.default_rng(1994)
n_samples = 32 n_samples = 32
n_features = 10 n_features = 10
X = rng.normal(0, 1, size=n_samples * n_features).reshape((n_samples, n_features)) X = rng.normal(0, 1, size=n_samples * n_features).reshape(
(n_samples, n_features)
)
alpha, loc, beta = 5.0, 11.1, 22 alpha, loc, beta = 5.0, 11.1, 22
y = stats.gamma.rvs(alpha, loc=loc, scale=beta, size=n_samples, random_state=rng) y = stats.gamma.rvs(
reg = xgb.XGBRegressor(tree_method="hist", objective="reg:gamma", n_estimators=64) alpha, loc=loc, scale=beta, size=n_samples, random_state=rng
reg.fit(X, y, eval_metric="gamma-nloglik", eval_set=[(X, y)]) )
reg = xgb.XGBRegressor(
tree_method="hist",
objective="reg:gamma",
n_estimators=64,
eval_metric="gamma-nloglik",
)
reg.fit(X, y, eval_set=[(X, y)])
score = reg.predict(X) score = reg.predict(X)
@ -134,7 +182,7 @@ class TestEvalMetrics:
# XGBoost uses the canonical link function of gamma in evaluation function. # XGBoost uses the canonical link function of gamma in evaluation function.
# so \theta = - (1.0 / y) # so \theta = - (1.0 / y)
# dispersion is hardcoded as 1.0, so shape (a in scipy parameter) is also 1.0 # dispersion is hardcoded as 1.0, so shape (a in scipy parameter) is also 1.0
beta = - (1.0 / (- (1.0 / y))) # == y beta = -(1.0 / (-(1.0 / y))) # == y
nloglik_stats = -stats.gamma.logpdf(score, a=1.0, scale=beta) nloglik_stats = -stats.gamma.logpdf(score, a=1.0, scale=beta)
np.testing.assert_allclose(nloglik, np.mean(nloglik_stats), rtol=1e-3) np.testing.assert_allclose(nloglik, np.mean(nloglik_stats), rtol=1e-3)
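
Spelling out the identity in the comment above: with the dispersion fixed at 1.0 the shape is 1.0 and the scale collapses back to the label itself, so the reported gamma-nloglik can be cross-checked against scipy. A toy-sized sketch with made-up numbers:

import numpy as np
import scipy.stats as stats

y = np.array([3.0, 7.5, 12.0])       # placeholder labels
score = np.array([2.8, 7.9, 11.5])   # placeholder predictions

# Canonical link: theta = -1 / y, so the implied scale -1 / theta is y itself.
beta = -(1.0 / (-(1.0 / y)))
assert np.allclose(beta, y)

nloglik = -stats.gamma.logpdf(score, a=1.0, scale=beta)
print(nloglik.mean())
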
@ -153,7 +201,7 @@ class TestEvalMetrics:
n_features, n_features,
n_informative=n_features, n_informative=n_features,
n_redundant=0, n_redundant=0,
random_state=rng random_state=rng,
) )
Xy = xgb.DMatrix(X, y) Xy = xgb.DMatrix(X, y)
booster = xgb.train( booster = xgb.train(
@ -197,7 +245,7 @@ class TestEvalMetrics:
n_informative=n_features, n_informative=n_features,
n_redundant=0, n_redundant=0,
n_classes=n_classes, n_classes=n_classes,
random_state=rng random_state=rng,
) )
if weighted: if weighted:
weights = rng.randn(n_samples) weights = rng.randn(n_samples)
@ -242,20 +290,25 @@ class TestEvalMetrics:
def run_pr_auc_binary(self, tree_method): def run_pr_auc_binary(self, tree_method):
from sklearn.datasets import make_classification from sklearn.datasets import make_classification
from sklearn.metrics import auc, precision_recall_curve from sklearn.metrics import auc, precision_recall_curve
X, y = make_classification(128, 4, n_classes=2, random_state=1994) X, y = make_classification(128, 4, n_classes=2, random_state=1994)
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=1) clf = xgb.XGBClassifier(
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)]) tree_method=tree_method, n_estimators=1, eval_metric="aucpr"
)
clf.fit(X, y, eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1] evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
y_score = clf.predict_proba(X)[:, 1] # get the positive column y_score = clf.predict_proba(X)[:, 1] # get the positive column
precision, recall, _ = precision_recall_curve(y, y_score) precision, recall, _ = precision_recall_curve(y, y_score)
prauc = auc(recall, precision) prauc = auc(recall, precision)
# Interpolation results are slightly different from sklearn, but overall should be # Interpolation results are slightly different from sklearn, but overall should
# similar. # be similar.
np.testing.assert_allclose(prauc, evals_result, rtol=1e-2) np.testing.assert_allclose(prauc, evals_result, rtol=1e-2)
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=10) clf = xgb.XGBClassifier(
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)]) tree_method=tree_method, n_estimators=10, eval_metric="aucpr"
)
clf.fit(X, y, eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1] evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
np.testing.assert_allclose(0.99, evals_result, rtol=1e-2) np.testing.assert_allclose(0.99, evals_result, rtol=1e-2)
@ -264,16 +317,21 @@ class TestEvalMetrics:
def run_pr_auc_multi(self, tree_method): def run_pr_auc_multi(self, tree_method):
from sklearn.datasets import make_classification from sklearn.datasets import make_classification
X, y = make_classification( X, y = make_classification(
64, 16, n_informative=8, n_classes=3, random_state=1994 64, 16, n_informative=8, n_classes=3, random_state=1994
) )
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=1) clf = xgb.XGBClassifier(
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)]) tree_method=tree_method, n_estimators=1, eval_metric="aucpr"
)
clf.fit(X, y, eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1] evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
# No available implementation for comparison, just check that XGBoost converges to # No available implementation for comparison, just check that XGBoost converges
# 1.0 # to 1.0
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=10) clf = xgb.XGBClassifier(
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)]) tree_method=tree_method, n_estimators=10, eval_metric="aucpr"
)
clf.fit(X, y, eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1] evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
np.testing.assert_allclose(1.0, evals_result, rtol=1e-2) np.testing.assert_allclose(1.0, evals_result, rtol=1e-2)
@ -282,9 +340,13 @@ class TestEvalMetrics:
def run_pr_auc_ltr(self, tree_method): def run_pr_auc_ltr(self, tree_method):
from sklearn.datasets import make_classification from sklearn.datasets import make_classification
X, y = make_classification(128, 4, n_classes=2, random_state=1994) X, y = make_classification(128, 4, n_classes=2, random_state=1994)
ltr = xgb.XGBRanker( ltr = xgb.XGBRanker(
tree_method=tree_method, n_estimators=16, objective="rank:pairwise" tree_method=tree_method,
n_estimators=16,
objective="rank:pairwise",
eval_metric="aucpr",
) )
groups = np.array([32, 32, 64]) groups = np.array([32, 32, 64])
ltr.fit( ltr.fit(
@ -293,7 +355,6 @@ class TestEvalMetrics:
group=groups, group=groups,
eval_set=[(X, y)], eval_set=[(X, y)],
eval_group=[groups], eval_group=[groups],
eval_metric="aucpr",
) )
results = ltr.evals_result()["validation_0"]["aucpr"] results = ltr.evals_result()["validation_0"]["aucpr"]
assert results[-1] >= 0.99 assert results[-1] >= 0.99
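
The ranker follows the same pattern, with the metric on the constructor and only data-dependent arguments left on fit(); a standalone sketch (synthetic data, placeholder query groups):

import numpy as np
import xgboost as xgb
from sklearn.datasets import make_classification

X, y = make_classification(128, 4, n_classes=2, random_state=1994)
groups = np.array([32, 32, 64])  # query group sizes must sum to len(y)

ltr = xgb.XGBRanker(n_estimators=16, objective="rank:pairwise", eval_metric="aucpr")
ltr.fit(X, y, group=groups, eval_set=[(X, y)], eval_group=[groups])
print(ltr.evals_result()["validation_0"]["aucpr"][-1])
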
@ -149,8 +149,8 @@ class TestTrainingContinuation:
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
clf = xgb.XGBClassifier(n_estimators=2) clf = xgb.XGBClassifier(n_estimators=2, eval_metric="logloss")
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss") clf.fit(X, y, eval_set=[(X, y)])
assert tm.non_increasing(clf.evals_result()["validation_0"]["logloss"]) assert tm.non_increasing(clf.evals_result()["validation_0"]["logloss"])
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
@ -160,5 +160,6 @@ class TestTrainingContinuation:
clf = xgb.XGBClassifier(n_estimators=2) clf = xgb.XGBClassifier(n_estimators=2)
# change metric to error # change metric to error
clf.fit(X, y, eval_set=[(X, y)], eval_metric="error") clf.set_params(eval_metric="error")
clf.fit(X, y, eval_set=[(X, y)], xgb_model=loaded)
assert tm.non_increasing(clf.evals_result()["validation_0"]["error"]) assert tm.non_increasing(clf.evals_result()["validation_0"]["error"])
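
The continuation pattern above as a self-contained sketch: the metric is switched through set_params() between sessions instead of through a fit() argument (toy data, tiny round counts):

import os
import tempfile

import xgboost as xgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)

clf = xgb.XGBClassifier(n_estimators=2, eval_metric="logloss")
clf.fit(X, y, eval_set=[(X, y)])

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "model.json")
    clf.save_model(path)

    loaded = xgb.XGBClassifier()
    loaded.load_model(path)

    # Change the metric on the estimator, then continue from the loaded model.
    clf = xgb.XGBClassifier(n_estimators=2)
    clf.set_params(eval_metric="error")
    clf.fit(X, y, eval_set=[(X, y)], xgb_model=loaded)
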
@ -30,8 +30,8 @@ def test_binary_classification():
kf = KFold(n_splits=2, shuffle=True, random_state=rng) kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier): for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
for train_index, test_index in kf.split(X, y): for train_index, test_index in kf.split(X, y):
clf = cls(random_state=42) clf = cls(random_state=42, eval_metric=['auc', 'logloss'])
xgb_model = clf.fit(X[train_index], y[train_index], eval_metric=['auc', 'logloss']) xgb_model = clf.fit(X[train_index], y[train_index])
preds = xgb_model.predict(X[test_index]) preds = xgb_model.predict(X[test_index])
labels = y[test_index] labels = y[test_index]
err = sum(1 for i in range(len(preds)) err = sum(1 for i in range(len(preds))
@ -101,10 +101,11 @@ def test_best_iteration():
def train(booster: str, forest: Optional[int]) -> None: def train(booster: str, forest: Optional[int]) -> None:
rounds = 4 rounds = 4
cls = xgb.XGBClassifier( cls = xgb.XGBClassifier(
n_estimators=rounds, num_parallel_tree=forest, booster=booster n_estimators=rounds,
).fit( num_parallel_tree=forest,
X, y, eval_set=[(X, y)], early_stopping_rounds=3 booster=booster,
) early_stopping_rounds=3,
).fit(X, y, eval_set=[(X, y)])
assert cls.best_iteration == rounds - 1 assert cls.best_iteration == rounds - 1
# best_iteration is used by default, assert that under gblinear it's # best_iteration is used by default, assert that under gblinear it's
@ -112,9 +113,9 @@ def test_best_iteration():
cls.predict(X) cls.predict(X)
num_parallel_tree = 4 num_parallel_tree = 4
train('gbtree', num_parallel_tree) train("gbtree", num_parallel_tree)
train('dart', num_parallel_tree) train("dart", num_parallel_tree)
train('gblinear', None) train("gblinear", None)
def test_ranking(): def test_ranking():
@ -258,6 +259,7 @@ def test_stacking_classification():
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
clf.fit(X_train, y_train).score(X_test, y_test) clf.fit(X_train, y_train).score(X_test, y_test)
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
def test_feature_importances_weight(): def test_feature_importances_weight():
from sklearn.datasets import load_digits from sklearn.datasets import load_digits
@ -474,7 +476,8 @@ def run_housing_rf_regression(tree_method):
rfreg = xgb.XGBRFRegressor() rfreg = xgb.XGBRFRegressor()
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
rfreg.fit(X, y, early_stopping_rounds=10) rfreg.set_params(early_stopping_rounds=10)
rfreg.fit(X, y)
def test_rf_regression(): def test_rf_regression():
@ -844,51 +847,65 @@ def run_validation_weights(model):
y_train, y_test = y[:1600], y[1600:] y_train, y_test = y[:1600], y[1600:]
# instantiate model # instantiate model
param_dist = {'objective': 'binary:logistic', 'n_estimators': 2, param_dist = {
'random_state': 123} "objective": "binary:logistic",
"n_estimators": 2,
"random_state": 123,
}
clf = model(**param_dist) clf = model(**param_dist)
# train it using instance weights only in the training set # train it using instance weights only in the training set
weights_train = np.random.choice([1, 2], len(X_train)) weights_train = np.random.choice([1, 2], len(X_train))
clf.fit(X_train, y_train, clf.set_params(eval_metric="logloss")
clf.fit(
X_train,
y_train,
sample_weight=weights_train, sample_weight=weights_train,
eval_set=[(X_test, y_test)], eval_set=[(X_test, y_test)],
eval_metric='logloss', verbose=False,
verbose=False) )
# evaluate logloss metric on test set *without* using weights # evaluate logloss metric on test set *without* using weights
evals_result_without_weights = clf.evals_result() evals_result_without_weights = clf.evals_result()
logloss_without_weights = evals_result_without_weights[ logloss_without_weights = evals_result_without_weights["validation_0"]["logloss"]
"validation_0"]["logloss"]
# now use weights for the test set # now use weights for the test set
np.random.seed(0) np.random.seed(0)
weights_test = np.random.choice([1, 2], len(X_test)) weights_test = np.random.choice([1, 2], len(X_test))
clf.fit(X_train, y_train, clf.set_params(eval_metric="logloss")
clf.fit(
X_train,
y_train,
sample_weight=weights_train, sample_weight=weights_train,
eval_set=[(X_test, y_test)], eval_set=[(X_test, y_test)],
sample_weight_eval_set=[weights_test], sample_weight_eval_set=[weights_test],
eval_metric='logloss', verbose=False,
verbose=False) )
evals_result_with_weights = clf.evals_result() evals_result_with_weights = clf.evals_result()
logloss_with_weights = evals_result_with_weights["validation_0"]["logloss"] logloss_with_weights = evals_result_with_weights["validation_0"]["logloss"]
# check that the logloss in the test set is actually different when using # check that the logloss in the test set is actually different when using
# weights than when not using them # weights than when not using them
assert all((logloss_with_weights[i] != logloss_without_weights[i] assert all((logloss_with_weights[i] != logloss_without_weights[i] for i in [0, 1]))
for i in [0, 1]))
with pytest.raises(ValueError): with pytest.raises(ValueError):
# length of eval set and sample weight doesn't match. # length of eval set and sample weight doesn't match.
clf.fit(X_train, y_train, sample_weight=weights_train, clf.fit(
X_train,
y_train,
sample_weight=weights_train,
eval_set=[(X_train, y_train), (X_test, y_test)], eval_set=[(X_train, y_train), (X_test, y_test)],
sample_weight_eval_set=[weights_train]) sample_weight_eval_set=[weights_train],
)
with pytest.raises(ValueError): with pytest.raises(ValueError):
cls = xgb.XGBClassifier() cls = xgb.XGBClassifier()
cls.fit(X_train, y_train, sample_weight=weights_train, cls.fit(
X_train,
y_train,
sample_weight=weights_train,
eval_set=[(X_train, y_train), (X_test, y_test)], eval_set=[(X_train, y_train), (X_test, y_test)],
sample_weight_eval_set=[weights_train]) sample_weight_eval_set=[weights_train],
)
def test_validation_weights(): def test_validation_weights():
@ -960,8 +977,7 @@ def test_XGBClassifier_resume():
# file name of stored xgb model # file name of stored xgb model
model1.save_model(model1_path) model1.save_model(model1_path)
model2 = xgb.XGBClassifier( model2 = xgb.XGBClassifier(learning_rate=0.3, random_state=0, n_estimators=8)
learning_rate=0.3, random_state=0, n_estimators=8)
model2.fit(X, Y, xgb_model=model1_path) model2.fit(X, Y, xgb_model=model1_path)
pred2 = model2.predict(X) pred2 = model2.predict(X)
@ -972,8 +988,7 @@ def test_XGBClassifier_resume():
# file name of 'Booster' instance Xgb model # file name of 'Booster' instance Xgb model
model1.get_booster().save_model(model1_booster_path) model1.get_booster().save_model(model1_booster_path)
model2 = xgb.XGBClassifier( model2 = xgb.XGBClassifier(learning_rate=0.3, random_state=0, n_estimators=8)
learning_rate=0.3, random_state=0, n_estimators=8)
model2.fit(X, Y, xgb_model=model1_booster_path) model2.fit(X, Y, xgb_model=model1_booster_path)
pred2 = model2.predict(X) pred2 = model2.predict(X)
@ -1279,12 +1294,16 @@ def test_estimator_reg(estimator, check):
): ):
estimator.fit(X, y) estimator.fit(X, y)
return return
if os.environ["PYTEST_CURRENT_TEST"].find("check_estimators_overwrite_params") != -1: if (
os.environ["PYTEST_CURRENT_TEST"].find("check_estimators_overwrite_params")
!= -1
):
# A hack to pass the scikit-learn parameter mutation tests. XGBoost regressor # A hack to pass the scikit-learn parameter mutation tests. XGBoost regressor
# returns actual internal default values for parameters in `get_params`, but those # returns actual internal default values for parameters in `get_params`, but
# are set as `None` in sklearn interface to avoid duplication. So we fit a dummy # those are set as `None` in sklearn interface to avoid duplication. So we fit
# model and obtain the default parameters here for the mutation tests. # a dummy model and obtain the default parameters here for the mutation tests.
from sklearn.datasets import make_regression from sklearn.datasets import make_regression
X, y = make_regression(n_samples=2, n_features=1) X, y = make_regression(n_samples=2, n_features=1)
estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params()) estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params())
@ -1325,6 +1344,7 @@ def test_categorical():
def test_evaluation_metric(): def test_evaluation_metric():
from sklearn.datasets import load_diabetes, load_digits from sklearn.datasets import load_diabetes, load_digits
from sklearn.metrics import mean_absolute_error from sklearn.metrics import mean_absolute_error
X, y = load_diabetes(return_X_y=True) X, y = load_diabetes(return_X_y=True)
n_estimators = 16 n_estimators = 16
@ -1341,17 +1361,6 @@ def test_evaluation_metric():
for line in lines: for line in lines:
assert line.find("mean_absolute_error") != -1 assert line.find("mean_absolute_error") != -1
def metric(predt: np.ndarray, Xy: xgb.DMatrix):
y = Xy.get_label()
return "m", np.abs(predt - y).sum()
with pytest.warns(UserWarning):
reg = xgb.XGBRegressor(
tree_method="hist",
n_estimators=1,
)
reg.fit(X, y, eval_set=[(X, y)], eval_metric=metric)
def merror(y_true: np.ndarray, predt: np.ndarray): def merror(y_true: np.ndarray, predt: np.ndarray):
n_samples = y_true.shape[0] n_samples = y_true.shape[0]
assert n_samples == predt.size assert n_samples == predt.size
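
With the warning path removed, custom metrics for the sklearn interface go through the constructor and use the sklearn-style (y_true, y_pred) signature; a minimal sketch using sklearn's mean_absolute_error as the callable:

import xgboost as xgb
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error

X, y = load_diabetes(return_X_y=True)

# A callable metric on the constructor is recorded under its __name__.
reg = xgb.XGBRegressor(
    tree_method="hist", n_estimators=4, eval_metric=mean_absolute_error
)
reg.fit(X, y, eval_set=[(X, y)], verbose=False)
print(reg.evals_result()["validation_0"]["mean_absolute_error"])
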
@ -363,12 +363,12 @@ class TestDistributedGPU:
device="cuda", device="cuda",
eval_metric="error", eval_metric="error",
n_estimators=100, n_estimators=100,
early_stopping_rounds=early_stopping_rounds,
) )
cls.client = local_cuda_client cls.client = local_cuda_client
cls.fit( cls.fit(
X, X,
y, y,
early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)], eval_set=[(valid_X, valid_y)],
) )
booster = cls.get_booster() booster = cls.get_booster()
@ -937,8 +937,10 @@ def run_empty_dmatrix_auc(client: "Client", device: str, n_workers: int) -> None
valid_X = dd.from_array(valid_X_, chunksize=n_samples) valid_X = dd.from_array(valid_X_, chunksize=n_samples)
valid_y = dd.from_array(valid_y_, chunksize=n_samples) valid_y = dd.from_array(valid_y_, chunksize=n_samples)
cls = xgb.dask.DaskXGBClassifier(device=device, n_estimators=2) cls = xgb.dask.DaskXGBClassifier(
cls.fit(X, y, eval_metric=["auc", "aucpr"], eval_set=[(valid_X, valid_y)]) device=device, n_estimators=2, eval_metric=["auc", "aucpr"]
)
cls.fit(X, y, eval_set=[(valid_X, valid_y)])
# multiclass # multiclass
X_, y_ = make_classification( X_, y_ = make_classification(
@ -966,8 +968,10 @@ def run_empty_dmatrix_auc(client: "Client", device: str, n_workers: int) -> None
valid_X = dd.from_array(valid_X_, chunksize=n_samples) valid_X = dd.from_array(valid_X_, chunksize=n_samples)
valid_y = dd.from_array(valid_y_, chunksize=n_samples) valid_y = dd.from_array(valid_y_, chunksize=n_samples)
cls = xgb.dask.DaskXGBClassifier(device=device, n_estimators=2) cls = xgb.dask.DaskXGBClassifier(
cls.fit(X, y, eval_metric=["auc", "aucpr"], eval_set=[(valid_X, valid_y)]) device=device, n_estimators=2, eval_metric=["auc", "aucpr"]
)
cls.fit(X, y, eval_set=[(valid_X, valid_y)])
def test_empty_dmatrix_auc() -> None: def test_empty_dmatrix_auc() -> None:
@ -994,11 +998,11 @@ def run_auc(client: "Client", device: str) -> None:
valid_X = dd.from_array(valid_X_, chunksize=10) valid_X = dd.from_array(valid_X_, chunksize=10)
valid_y = dd.from_array(valid_y_, chunksize=10) valid_y = dd.from_array(valid_y_, chunksize=10)
cls = xgb.XGBClassifier(device=device, n_estimators=2) cls = xgb.XGBClassifier(device=device, n_estimators=2, eval_metric="auc")
cls.fit(X_, y_, eval_metric="auc", eval_set=[(valid_X_, valid_y_)]) cls.fit(X_, y_, eval_set=[(valid_X_, valid_y_)])
dcls = xgb.dask.DaskXGBClassifier(device=device, n_estimators=2) dcls = xgb.dask.DaskXGBClassifier(device=device, n_estimators=2, eval_metric="auc")
dcls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)]) dcls.fit(X, y, eval_set=[(valid_X, valid_y)])
approx = dcls.evals_result()["validation_0"]["auc"] approx = dcls.evals_result()["validation_0"]["auc"]
exact = cls.evals_result()["validation_0"]["auc"] exact = cls.evals_result()["validation_0"]["auc"]
@ -1267,16 +1271,16 @@ def test_dask_ranking(client: "Client") -> None:
qid_valid = qid_valid.astype(np.uint32) qid_valid = qid_valid.astype(np.uint32)
qid_test = qid_test.astype(np.uint32) qid_test = qid_test.astype(np.uint32)
rank = xgb.dask.DaskXGBRanker(n_estimators=2500) rank = xgb.dask.DaskXGBRanker(
n_estimators=2500, eval_metric=["ndcg"], early_stopping_rounds=10
)
rank.fit( rank.fit(
x_train, x_train,
y_train, y_train,
qid=qid_train, qid=qid_train,
eval_set=[(x_test, y_test), (x_train, y_train)], eval_set=[(x_test, y_test), (x_train, y_train)],
eval_qid=[qid_test, qid_train], eval_qid=[qid_test, qid_train],
eval_metric=["ndcg"],
verbose=True, verbose=True,
early_stopping_rounds=10,
) )
assert rank.n_features_in_ == 46 assert rank.n_features_in_ == 46
assert rank.best_score > 0.98 assert rank.best_score > 0.98
@ -2150,13 +2154,15 @@ class TestDaskCallbacks:
valid_X, valid_y = load_breast_cancer(return_X_y=True) valid_X, valid_y = load_breast_cancer(return_X_y=True)
valid_X, valid_y = da.from_array(valid_X), da.from_array(valid_y) valid_X, valid_y = da.from_array(valid_X), da.from_array(valid_y)
cls = xgb.dask.DaskXGBClassifier( cls = xgb.dask.DaskXGBClassifier(
objective="binary:logistic", tree_method="hist", n_estimators=1000 objective="binary:logistic",
tree_method="hist",
n_estimators=1000,
early_stopping_rounds=early_stopping_rounds,
) )
cls.client = client cls.client = client
cls.fit( cls.fit(
X, X,
y, y,
early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)], eval_set=[(valid_X, valid_y)],
) )
booster = cls.get_booster() booster = cls.get_booster()
@ -2165,15 +2171,17 @@ class TestDaskCallbacks:
# Specify the metric # Specify the metric
cls = xgb.dask.DaskXGBClassifier( cls = xgb.dask.DaskXGBClassifier(
objective="binary:logistic", tree_method="hist", n_estimators=1000 objective="binary:logistic",
tree_method="hist",
n_estimators=1000,
early_stopping_rounds=early_stopping_rounds,
eval_metric="error",
) )
cls.client = client cls.client = client
cls.fit( cls.fit(
X, X,
y, y,
early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)], eval_set=[(valid_X, valid_y)],
eval_metric="error",
) )
assert tm.non_increasing(cls.evals_result()["validation_0"]["error"]) assert tm.non_increasing(cls.evals_result()["validation_0"]["error"])
booster = cls.get_booster() booster = cls.get_booster()
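
The same constructor-based configuration applies to the Dask estimators; a rough sketch on a throwaway local cluster (placeholder chunk sizes and round counts):

import dask.array as da
from dask.distributed import Client, LocalCluster
from sklearn.datasets import load_breast_cancer
from xgboost import dask as dxgb

if __name__ == "__main__":
    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster, Client(
        cluster
    ) as client:
        X_np, y_np = load_breast_cancer(return_X_y=True)
        X = da.from_array(X_np, chunks=(100, X_np.shape[1]))
        y = da.from_array(y_np, chunks=100)

        # Metric and stopping rule sit on the estimator, as in the test above.
        clf = dxgb.DaskXGBClassifier(
            n_estimators=100, eval_metric="error", early_stopping_rounds=5
        )
        clf.client = client
        clf.fit(X, y, eval_set=[(X, y)])
        print("best iteration:", clf.best_iteration)
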
@ -2215,12 +2223,12 @@ class TestDaskCallbacks:
tree_method="hist", tree_method="hist",
n_estimators=1000, n_estimators=1000,
eval_metric=tm.eval_error_metric_skl, eval_metric=tm.eval_error_metric_skl,
early_stopping_rounds=early_stopping_rounds,
) )
cls.client = client cls.client = client
cls.fit( cls.fit(
X, X,
y, y,
early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)], eval_set=[(valid_X, valid_y)],
) )
booster = cls.get_booster() booster = cls.get_booster()
@ -2234,21 +2242,22 @@ class TestDaskCallbacks:
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
X, y = da.from_array(X), da.from_array(y) X, y = da.from_array(X), da.from_array(y)
cls = xgb.dask.DaskXGBClassifier(
objective="binary:logistic", tree_method="hist", n_estimators=10
)
cls.client = client
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
cls.fit( cls = xgb.dask.DaskXGBClassifier(
X, objective="binary:logistic",
y, tree_method="hist",
n_estimators=10,
callbacks=[ callbacks=[
xgb.callback.TrainingCheckPoint( xgb.callback.TrainingCheckPoint(
directory=Path(tmpdir), interval=1, name="model" directory=Path(tmpdir), interval=1, name="model"
) )
], ],
) )
cls.client = client
cls.fit(
X,
y,
)
for i in range(1, 10): for i in range(1, 10):
assert os.path.exists( assert os.path.exists(
os.path.join( os.path.join(
@ -311,24 +311,20 @@ def clf_with_weight(
y_val = np.array([0, 1]) y_val = np.array([0, 1])
w_train = np.array([1.0, 2.0]) w_train = np.array([1.0, 2.0])
w_val = np.array([1.0, 2.0]) w_val = np.array([1.0, 2.0])
cls2 = XGBClassifier() cls2 = XGBClassifier(eval_metric="logloss", early_stopping_rounds=1)
cls2.fit( cls2.fit(
X_train, X_train,
y_train, y_train,
eval_set=[(X_val, y_val)], eval_set=[(X_val, y_val)],
early_stopping_rounds=1,
eval_metric="logloss",
) )
cls3 = XGBClassifier() cls3 = XGBClassifier(eval_metric="logloss", early_stopping_rounds=1)
cls3.fit( cls3.fit(
X_train, X_train,
y_train, y_train,
sample_weight=w_train, sample_weight=w_train,
eval_set=[(X_val, y_val)], eval_set=[(X_val, y_val)],
sample_weight_eval_set=[w_val], sample_weight_eval_set=[w_val],
early_stopping_rounds=1,
eval_metric="logloss",
) )
cls_df_train_with_eval_weight = spark.createDataFrame( cls_df_train_with_eval_weight = spark.createDataFrame(