Fix filtering callable objects in skl xgb param. (#6466)

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2020-12-05 17:20:36 +08:00
committed by Hyunsu Cho
parent 2b3e301543
commit c39f6b25f0
2 changed files with 16 additions and 1 deletions

View File

@@ -399,6 +399,21 @@ def test_classification_with_custom_objective():
X, y
)
cls = xgb.XGBClassifier(use_label_encoder=False, n_estimators=1)
cls.fit(X, y)
is_called = [False]
def wrapped(y, p):
is_called[0] = True
return logregobj(y, p)
cls.set_params(objective=wrapped)
cls.predict(X) # no throw
cls.fit(X, y)
assert is_called[0]
def test_sklearn_api():
from sklearn.datasets import load_iris