Add tests for pickling with custom obj and metric. (#9943)
This commit is contained in:
parent
26a5436a65
commit
5f7b5a6921
@ -279,7 +279,6 @@ available at :ref:`sphx_glr_python_examples_custom_softmax.py`. Also, see
|
|||||||
Scikit-Learn Interface
|
Scikit-Learn Interface
|
||||||
**********************
|
**********************
|
||||||
|
|
||||||
|
|
||||||
The scikit-learn interface of XGBoost has some utilities to improve the integration with
|
The scikit-learn interface of XGBoost has some utilities to improve the integration with
|
||||||
standard scikit-learn functions. For instance, after XGBoost 1.6.0 users can use the cost
|
standard scikit-learn functions. For instance, after XGBoost 1.6.0 users can use the cost
|
||||||
function (not scoring functions) from scikit-learn out of the box:
|
function (not scoring functions) from scikit-learn out of the box:
|
||||||
|
|||||||
@ -101,6 +101,8 @@ snapshot generated by an earlier version of XGBoost may result in errors or unde
|
|||||||
**If a model is persisted with** ``pickle.dump`` (Python) or ``saveRDS`` (R), **then the model may
|
**If a model is persisted with** ``pickle.dump`` (Python) or ``saveRDS`` (R), **then the model may
|
||||||
not be accessible in later versions of XGBoost.**
|
not be accessible in later versions of XGBoost.**
|
||||||
|
|
||||||
|
.. _custom-obj-metric:
|
||||||
|
|
||||||
***************************
|
***************************
|
||||||
Custom objective and metric
|
Custom objective and metric
|
||||||
***************************
|
***************************
|
||||||
|
|||||||
@ -192,11 +192,16 @@ __model_doc = f"""
|
|||||||
Boosting learning rate (xgb's "eta")
|
Boosting learning rate (xgb's "eta")
|
||||||
verbosity : Optional[int]
|
verbosity : Optional[int]
|
||||||
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
|
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
|
||||||
|
|
||||||
objective : {SklObjective}
|
objective : {SklObjective}
|
||||||
Specify the learning task and the corresponding learning objective or
|
|
||||||
a custom objective function to be used (see note below).
|
Specify the learning task and the corresponding learning objective or a custom
|
||||||
|
objective function to be used. For custom objective, see
|
||||||
|
:doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more
|
||||||
|
information.
|
||||||
|
|
||||||
booster: Optional[str]
|
booster: Optional[str]
|
||||||
Specify which booster to use: gbtree, gblinear or dart.
|
Specify which booster to use: `gbtree`, `gblinear` or `dart`.
|
||||||
tree_method: Optional[str]
|
tree_method: Optional[str]
|
||||||
Specify which tree method to use. Default to auto. If this parameter is set to
|
Specify which tree method to use. Default to auto. If this parameter is set to
|
||||||
default, XGBoost will choose the most conservative option available. It's
|
default, XGBoost will choose the most conservative option available. It's
|
||||||
@ -328,21 +333,21 @@ __model_doc = f"""
|
|||||||
|
|
||||||
Metric used for monitoring the training result and early stopping. It can be a
|
Metric used for monitoring the training result and early stopping. It can be a
|
||||||
string or list of strings as names of predefined metric in XGBoost (See
|
string or list of strings as names of predefined metric in XGBoost (See
|
||||||
doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any other
|
doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any
|
||||||
user defined metric that looks like `sklearn.metrics`.
|
other user defined metric that looks like `sklearn.metrics`.
|
||||||
|
|
||||||
If custom objective is also provided, then custom metric should implement the
|
If custom objective is also provided, then custom metric should implement the
|
||||||
corresponding reverse link function.
|
corresponding reverse link function.
|
||||||
|
|
||||||
Unlike the `scoring` parameter commonly used in scikit-learn, when a callable
|
Unlike the `scoring` parameter commonly used in scikit-learn, when a callable
|
||||||
object is provided, it's assumed to be a cost function and by default XGBoost will
|
object is provided, it's assumed to be a cost function and by default XGBoost
|
||||||
minimize the result during early stopping.
|
will minimize the result during early stopping.
|
||||||
|
|
||||||
For advanced usage on Early stopping like directly choosing to maximize instead of
|
For advanced usage on Early stopping like directly choosing to maximize instead
|
||||||
minimize, see :py:obj:`xgboost.callback.EarlyStopping`.
|
of minimize, see :py:obj:`xgboost.callback.EarlyStopping`.
|
||||||
|
|
||||||
See :doc:`Custom Objective and Evaluation Metric </tutorials/custom_metric_obj>`
|
See :doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more
|
||||||
for more.
|
information.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
|
|||||||
@ -815,6 +815,13 @@ def softprob_obj(
|
|||||||
return objective
|
return objective
|
||||||
|
|
||||||
|
|
||||||
|
def ls_obj(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Least squared error."""
|
||||||
|
grad = y_pred - y_true
|
||||||
|
hess = np.ones(len(y_true))
|
||||||
|
return grad, hess
|
||||||
|
|
||||||
|
|
||||||
class DirectoryExcursion:
|
class DirectoryExcursion:
|
||||||
"""Change directory. Change back and optionally cleaning up the directory when
|
"""Change directory. Change back and optionally cleaning up the directory when
|
||||||
exit.
|
exit.
|
||||||
|
|||||||
@ -1,10 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
|
from xgboost import testing as tm
|
||||||
|
|
||||||
kRows = 100
|
kRows = 100
|
||||||
kCols = 10
|
kCols = 10
|
||||||
@ -61,3 +64,27 @@ class TestPickling:
|
|||||||
params = {"nthread": 8, "tree_method": "exact", "subsample": 0.5}
|
params = {"nthread": 8, "tree_method": "exact", "subsample": 0.5}
|
||||||
config = self.run_model_pickling(params)
|
config = self.run_model_pickling(params)
|
||||||
check(config)
|
check(config)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
|
def test_with_sklearn_obj_metric(self) -> None:
|
||||||
|
from sklearn.metrics import mean_squared_error
|
||||||
|
|
||||||
|
X, y = tm.datasets.make_regression()
|
||||||
|
reg = xgb.XGBRegressor(objective=tm.ls_obj, eval_metric=mean_squared_error)
|
||||||
|
reg.fit(X, y)
|
||||||
|
|
||||||
|
pkl = pickle.dumps(reg)
|
||||||
|
reg_1 = pickle.loads(pkl)
|
||||||
|
assert callable(reg_1.objective)
|
||||||
|
assert callable(reg_1.eval_metric)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
path = os.path.join(tmpdir, "model.json")
|
||||||
|
reg.save_model(path)
|
||||||
|
|
||||||
|
reg_2 = xgb.XGBRegressor()
|
||||||
|
reg_2.load_model(path)
|
||||||
|
|
||||||
|
assert not callable(reg_2.objective)
|
||||||
|
assert not callable(reg_2.eval_metric)
|
||||||
|
assert reg_2.eval_metric is None
|
||||||
|
|||||||
@ -504,15 +504,10 @@ def test_regression_with_custom_objective():
|
|||||||
from sklearn.metrics import mean_squared_error
|
from sklearn.metrics import mean_squared_error
|
||||||
from sklearn.model_selection import KFold
|
from sklearn.model_selection import KFold
|
||||||
|
|
||||||
def objective_ls(y_true, y_pred):
|
|
||||||
grad = (y_pred - y_true)
|
|
||||||
hess = np.ones(len(y_true))
|
|
||||||
return grad, hess
|
|
||||||
|
|
||||||
X, y = fetch_california_housing(return_X_y=True)
|
X, y = fetch_california_housing(return_X_y=True)
|
||||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||||
for train_index, test_index in kf.split(X, y):
|
for train_index, test_index in kf.split(X, y):
|
||||||
xgb_model = xgb.XGBRegressor(objective=objective_ls).fit(
|
xgb_model = xgb.XGBRegressor(objective=tm.ls_obj).fit(
|
||||||
X[train_index], y[train_index]
|
X[train_index], y[train_index]
|
||||||
)
|
)
|
||||||
preds = xgb_model.predict(X[test_index])
|
preds = xgb_model.predict(X[test_index])
|
||||||
@ -530,27 +525,29 @@ def test_regression_with_custom_objective():
|
|||||||
np.testing.assert_raises(XGBCustomObjectiveException, xgb_model.fit, X, y)
|
np.testing.assert_raises(XGBCustomObjectiveException, xgb_model.fit, X, y)
|
||||||
|
|
||||||
|
|
||||||
def test_classification_with_custom_objective():
|
def logregobj(y_true, y_pred):
|
||||||
from sklearn.datasets import load_digits
|
|
||||||
from sklearn.model_selection import KFold
|
|
||||||
|
|
||||||
def logregobj(y_true, y_pred):
|
|
||||||
y_pred = 1.0 / (1.0 + np.exp(-y_pred))
|
y_pred = 1.0 / (1.0 + np.exp(-y_pred))
|
||||||
grad = y_pred - y_true
|
grad = y_pred - y_true
|
||||||
hess = y_pred * (1.0 - y_pred)
|
hess = y_pred * (1.0 - y_pred)
|
||||||
return grad, hess
|
return grad, hess
|
||||||
|
|
||||||
|
|
||||||
|
def test_classification_with_custom_objective():
|
||||||
|
from sklearn.datasets import load_digits
|
||||||
|
from sklearn.model_selection import KFold
|
||||||
|
|
||||||
digits = load_digits(n_class=2)
|
digits = load_digits(n_class=2)
|
||||||
y = digits['target']
|
y = digits["target"]
|
||||||
X = digits['data']
|
X = digits["data"]
|
||||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||||
for train_index, test_index in kf.split(X, y):
|
for train_index, test_index in kf.split(X, y):
|
||||||
xgb_model = xgb.XGBClassifier(objective=logregobj)
|
xgb_model = xgb.XGBClassifier(objective=logregobj)
|
||||||
xgb_model.fit(X[train_index], y[train_index])
|
xgb_model.fit(X[train_index], y[train_index])
|
||||||
preds = xgb_model.predict(X[test_index])
|
preds = xgb_model.predict(X[test_index])
|
||||||
labels = y[test_index]
|
labels = y[test_index]
|
||||||
err = sum(1 for i in range(len(preds))
|
err = sum(
|
||||||
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
|
1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]
|
||||||
|
) / float(len(preds))
|
||||||
assert err < 0.1
|
assert err < 0.1
|
||||||
|
|
||||||
# Test that the custom objective function is actually used
|
# Test that the custom objective function is actually used
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user