Support sample weight in sklearn custom objective. (#10050)
This commit is contained in:
@@ -517,6 +517,12 @@ def test_regression_with_custom_objective():
|
||||
labels = y[test_index]
|
||||
assert mean_squared_error(preds, labels) < 25
|
||||
|
||||
w = rng.uniform(low=0.0, high=1.0, size=X.shape[0])
|
||||
reg = xgb.XGBRegressor(objective=tm.ls_obj, n_estimators=25)
|
||||
reg.fit(X, y, sample_weight=w)
|
||||
y_pred = reg.predict(X)
|
||||
assert mean_squared_error(y_true=y, y_pred=y_pred, sample_weight=w) < 25
|
||||
|
||||
# Test that the custom objective function is actually used
|
||||
class XGBCustomObjectiveException(Exception):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user