Support sample weight in sklearn custom objective. (#10050)

This commit is contained in:
Jiaming Yuan
2024-02-21 00:43:14 +08:00
committed by GitHub
parent 69a17d5114
commit 8ea705e4d5
6 changed files with 179 additions and 69 deletions

View File

@@ -517,6 +517,12 @@ def test_regression_with_custom_objective():
labels = y[test_index]
assert mean_squared_error(preds, labels) < 25
w = rng.uniform(low=0.0, high=1.0, size=X.shape[0])
reg = xgb.XGBRegressor(objective=tm.ls_obj, n_estimators=25)
reg.fit(X, y, sample_weight=w)
y_pred = reg.predict(X)
assert mean_squared_error(y_true=y, y_pred=y_pred, sample_weight=w) < 25
# Test that the custom objective function is actually used
class XGBCustomObjectiveException(Exception):
pass