Fix filtering callable objects in skl xgb param. (#6466)

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2020-12-05 17:20:36 +08:00 committed by GitHub
parent 05e5563c2c
commit d6386e45e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 1 deletions

View File

@ -398,7 +398,7 @@ class XGBModel(XGBModelBase):
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder'}
filtered = dict()
for k, v in params.items():
if k not in wrapper_specific:
if k not in wrapper_specific and not callable(v):
filtered[k] = v
return filtered

View File

@ -399,6 +399,21 @@ def test_classification_with_custom_objective():
X, y
)
cls = xgb.XGBClassifier(use_label_encoder=False, n_estimators=1)
cls.fit(X, y)
is_called = [False]
def wrapped(y, p):
is_called[0] = True
return logregobj(y, p)
cls.set_params(objective=wrapped)
cls.predict(X) # no throw
cls.fit(X, y)
assert is_called[0]
def test_sklearn_api():
from sklearn.datasets import load_iris