Add tests for pickling with custom obj and metric. (#9943)

This commit is contained in:
Jiaming Yuan
2024-01-04 14:52:48 +08:00
committed by GitHub
parent 26a5436a65
commit 5f7b5a6921
6 changed files with 65 additions and 28 deletions

View File

@@ -504,15 +504,10 @@ def test_regression_with_custom_objective():
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold
def objective_ls(y_true, y_pred):
grad = (y_pred - y_true)
hess = np.ones(len(y_true))
return grad, hess
X, y = fetch_california_housing(return_X_y=True)
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBRegressor(objective=objective_ls).fit(
xgb_model = xgb.XGBRegressor(objective=tm.ls_obj).fit(
X[train_index], y[train_index]
)
preds = xgb_model.predict(X[test_index])
@@ -530,27 +525,29 @@ def test_regression_with_custom_objective():
np.testing.assert_raises(XGBCustomObjectiveException, xgb_model.fit, X, y)
def logregobj(y_true, y_pred):
y_pred = 1.0 / (1.0 + np.exp(-y_pred))
grad = y_pred - y_true
hess = y_pred * (1.0 - y_pred)
return grad, hess
def test_classification_with_custom_objective():
from sklearn.datasets import load_digits
from sklearn.model_selection import KFold
def logregobj(y_true, y_pred):
y_pred = 1.0 / (1.0 + np.exp(-y_pred))
grad = y_pred - y_true
hess = y_pred * (1.0 - y_pred)
return grad, hess
digits = load_digits(n_class=2)
y = digits['target']
X = digits['data']
y = digits["target"]
X = digits["data"]
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBClassifier(objective=logregobj)
xgb_model.fit(X[train_index], y[train_index])
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
err = sum(
1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]
) / float(len(preds))
assert err < 0.1
# Test that the custom objective function is actually used