Add tests for pickling with custom obj and metric. (#9943)

Jiaming Yuan 2024-01-04 14:52:48 +08:00 committed by GitHub
parent 26a5436a65
commit 5f7b5a6921
6 changed files with 65 additions and 28 deletions


@@ -279,7 +279,6 @@ available at :ref:`sphx_glr_python_examples_custom_softmax.py`. Also, see

Scikit-Learn Interface
**********************

The scikit-learn interface of XGBoost has some utilities to improve the integration with
standard scikit-learn functions. For instance, after XGBoost 1.6.0 users can use the cost
function (not scoring functions) from scikit-learn out of the box:
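
The code block that follows that sentence in the tutorial is not reproduced in this hunk. As a hedged stand-in (synthetic data and made-up parameters), the idea is that a scikit-learn cost function such as mean_squared_error can be handed straight to the estimator as eval_metric:

    from sklearn.datasets import make_regression
    from sklearn.metrics import mean_squared_error

    import xgboost as xgb

    X, y = make_regression(n_samples=256, n_features=16, random_state=2024)
    # A scikit-learn cost function (lower is better) is passed directly as
    # eval_metric; XGBoost evaluates it on the eval_set every iteration.
    reg = xgb.XGBRegressor(n_estimators=32, eval_metric=mean_squared_error)
    reg.fit(X, y, eval_set=[(X, y)])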


@@ -101,6 +101,8 @@ snapshot generated by an earlier version of XGBoost may result in errors or unde

**If a model is persisted with** ``pickle.dump`` (Python) or ``saveRDS`` (R), **then the model may
not be accessible in later versions of XGBoost.**

+.. _custom-obj-metric:
+
***************************
Custom objective and metric
***************************
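
As a hedged illustration of how this section relates to the pickling warning above (not text from the documentation itself): a custom objective or metric is a plain Python callable, so it survives a pickle round trip but is not written into the JSON file produced by save_model. The helper name squared_error and all data below are made up for the example:

    import pickle

    import numpy as np
    import xgboost as xgb


    def squared_error(y_true, y_pred):
        # Gradient and hessian of 1/2 * (y_pred - y_true)**2
        return y_pred - y_true, np.ones_like(y_true)


    rng = np.random.default_rng(0)
    X, y = rng.normal(size=(128, 4)), rng.normal(size=128)
    reg = xgb.XGBRegressor(n_estimators=8, objective=squared_error).fit(X, y)

    # pickle round-trips the whole Python object, callables included.
    assert callable(pickle.loads(pickle.dumps(reg)).objective)

    # The JSON model does not embed Python functions; after load_model the
    # objective must be supplied again if training is to continue.
    reg.save_model("model.json")
    fresh = xgb.XGBRegressor()
    fresh.load_model("model.json")
    assert not callable(fresh.objective)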


@@ -192,11 +192,16 @@ __model_doc = f"""
        Boosting learning rate (xgb's "eta")
    verbosity : Optional[int]
        The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
    objective : {SklObjective}
-        Specify the learning task and the corresponding learning objective or
-        a custom objective function to be used (see note below).
+        Specify the learning task and the corresponding learning objective or a custom
+        objective function to be used. For custom objective, see
+        :doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more
+        information.
    booster: Optional[str]
-        Specify which booster to use: gbtree, gblinear or dart.
+        Specify which booster to use: `gbtree`, `gblinear` or `dart`.
    tree_method: Optional[str]
        Specify which tree method to use. Default to auto. If this parameter is set to
        default, XGBoost will choose the most conservative option available. It's
@@ -328,21 +333,21 @@ __model_doc = f"""
        Metric used for monitoring the training result and early stopping. It can be a
        string or list of strings as names of predefined metric in XGBoost (See
-        doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any other
-        user defined metric that looks like `sklearn.metrics`.
+        doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any
+        other user defined metric that looks like `sklearn.metrics`.

        If custom objective is also provided, then custom metric should implement the
        corresponding reverse link function.

        Unlike the `scoring` parameter commonly used in scikit-learn, when a callable
-        object is provided, it's assumed to be a cost function and by default XGBoost will
-        minimize the result during early stopping.
+        object is provided, it's assumed to be a cost function and by default XGBoost
+        will minimize the result during early stopping.

-        For advanced usage on Early stopping like directly choosing to maximize instead of
-        minimize, see :py:obj:`xgboost.callback.EarlyStopping`.
+        For advanced usage on Early stopping like directly choosing to maximize instead
+        of minimize, see :py:obj:`xgboost.callback.EarlyStopping`.

-        See :doc:`Custom Objective and Evaluation Metric </tutorials/custom_metric_obj>`
-        for more.
+        See :doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more
+        information.

        .. note::

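To make the eval_metric semantics in the docstring above concrete, here is a hedged sketch of the "maximize instead of minimize" case using xgboost.callback.EarlyStopping; the scorer, data and parameters are illustrative only:

    from sklearn.datasets import make_regression
    from sklearn.metrics import r2_score

    import xgboost as xgb

    X, y = make_regression(n_samples=512, n_features=8, random_state=7)
    X_tr, y_tr, X_va, y_va = X[:400], y[:400], X[400:], y[400:]

    # r2_score is a score (higher is better), so early stopping is told to maximize
    # instead of applying the default cost-minimization behaviour.
    early_stop = xgb.callback.EarlyStopping(rounds=10, maximize=True, save_best=True)
    reg = xgb.XGBRegressor(n_estimators=200, eval_metric=r2_score, callbacks=[early_stop])
    reg.fit(X_tr, y_tr, eval_set=[(X_va, y_va)])
    print(reg.best_iteration)
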

@@ -815,6 +815,13 @@ def softprob_obj(
    return objective


+def ls_obj(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    """Least squared error."""
+    grad = y_pred - y_true
+    hess = np.ones(len(y_true))
+    return grad, hess


class DirectoryExcursion:
    """Change directory. Change back and optionally cleaning up the directory when
    exit.

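For reference, ls_obj returns the per-sample derivatives of the squared-error loss: with L(y, ŷ) = ½ (ŷ − y)², the gradient is ∂L/∂ŷ = ŷ − y and the hessian is ∂²L/∂ŷ² = 1, which matches grad = y_pred - y_true and hess = np.ones(len(y_true)). A minimal usage sketch (synthetic data and illustrative parameters, not part of the commit):

    import numpy as np
    import xgboost as xgb
    from xgboost import testing as tm

    rng = np.random.default_rng(1994)
    X = rng.normal(size=(64, 3))
    y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=64)

    # The helper matches the sklearn-wrapper objective signature:
    # objective(y_true, y_pred) -> (grad, hess).
    reg = xgb.XGBRegressor(n_estimators=8, objective=tm.ls_obj)
    reg.fit(X, y)
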

@@ -1,10 +1,13 @@
import json
import os
import pickle
+import tempfile

import numpy as np
+import pytest

import xgboost as xgb
+from xgboost import testing as tm

kRows = 100
kCols = 10
@@ -61,3 +64,27 @@ class TestPickling:
        params = {"nthread": 8, "tree_method": "exact", "subsample": 0.5}
        config = self.run_model_pickling(params)
        check(config)

+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_with_sklearn_obj_metric(self) -> None:
+        from sklearn.metrics import mean_squared_error
+
+        X, y = tm.datasets.make_regression()
+        reg = xgb.XGBRegressor(objective=tm.ls_obj, eval_metric=mean_squared_error)
+        reg.fit(X, y)
+
+        pkl = pickle.dumps(reg)
+        reg_1 = pickle.loads(pkl)
+        assert callable(reg_1.objective)
+        assert callable(reg_1.eval_metric)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = os.path.join(tmpdir, "model.json")
+            reg.save_model(path)
+
+            reg_2 = xgb.XGBRegressor()
+            reg_2.load_model(path)
+
+        assert not callable(reg_2.objective)
+        assert not callable(reg_2.eval_metric)
+        assert reg_2.eval_metric is None
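
A possible continuation of the test's final scenario (illustrative, not part of the commit): because the JSON file stores only built-in objective names, the Python callables have to be re-attached to the reloaded estimator, for example via set_params. The snippet below assumes reg_2, tm and mean_squared_error from the test above are in scope:

    # Re-attach the custom callables after load_model before further training.
    reg_2.set_params(objective=tm.ls_obj, eval_metric=mean_squared_error)
    assert callable(reg_2.objective) and callable(reg_2.eval_metric)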


@@ -504,15 +504,10 @@ def test_regression_with_custom_objective():
    from sklearn.metrics import mean_squared_error
    from sklearn.model_selection import KFold

-    def objective_ls(y_true, y_pred):
-        grad = (y_pred - y_true)
-        hess = np.ones(len(y_true))
-        return grad, hess
-
    X, y = fetch_california_housing(return_X_y=True)
    kf = KFold(n_splits=2, shuffle=True, random_state=rng)
    for train_index, test_index in kf.split(X, y):
-        xgb_model = xgb.XGBRegressor(objective=objective_ls).fit(
+        xgb_model = xgb.XGBRegressor(objective=tm.ls_obj).fit(
            X[train_index], y[train_index]
        )
        preds = xgb_model.predict(X[test_index])
@@ -530,27 +525,29 @@ def test_regression_with_custom_objective():
    np.testing.assert_raises(XGBCustomObjectiveException, xgb_model.fit, X, y)


+def logregobj(y_true, y_pred):
+    y_pred = 1.0 / (1.0 + np.exp(-y_pred))
+    grad = y_pred - y_true
+    hess = y_pred * (1.0 - y_pred)
+    return grad, hess
+
+
def test_classification_with_custom_objective():
    from sklearn.datasets import load_digits
    from sklearn.model_selection import KFold

-    def logregobj(y_true, y_pred):
-        y_pred = 1.0 / (1.0 + np.exp(-y_pred))
-        grad = y_pred - y_true
-        hess = y_pred * (1.0 - y_pred)
-        return grad, hess
-
    digits = load_digits(n_class=2)
-    y = digits['target']
-    X = digits['data']
+    y = digits["target"]
+    X = digits["data"]
    kf = KFold(n_splits=2, shuffle=True, random_state=rng)
    for train_index, test_index in kf.split(X, y):
        xgb_model = xgb.XGBClassifier(objective=logregobj)
        xgb_model.fit(X[train_index], y[train_index])
        preds = xgb_model.predict(X[test_index])
        labels = y[test_index]
-        err = sum(1 for i in range(len(preds))
-                  if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+        err = sum(
+            1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]
+        ) / float(len(preds))
        assert err < 0.1
        # Test that the custom objective function is actually used