Fix filtering callable objects in skl xgb param. (#6466)
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
05e5563c2c
commit
d6386e45e8
@ -398,7 +398,7 @@ class XGBModel(XGBModelBase):
|
||||
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder'}
|
||||
filtered = dict()
|
||||
for k, v in params.items():
|
||||
if k not in wrapper_specific:
|
||||
if k not in wrapper_specific and not callable(v):
|
||||
filtered[k] = v
|
||||
return filtered
|
||||
|
||||
|
||||
@ -399,6 +399,21 @@ def test_classification_with_custom_objective():
|
||||
X, y
|
||||
)
|
||||
|
||||
cls = xgb.XGBClassifier(use_label_encoder=False, n_estimators=1)
|
||||
cls.fit(X, y)
|
||||
|
||||
is_called = [False]
|
||||
|
||||
def wrapped(y, p):
|
||||
is_called[0] = True
|
||||
return logregobj(y, p)
|
||||
|
||||
cls.set_params(objective=wrapped)
|
||||
cls.predict(X) # no throw
|
||||
cls.fit(X, y)
|
||||
|
||||
assert is_called[0]
|
||||
|
||||
|
||||
def test_sklearn_api():
|
||||
from sklearn.datasets import load_iris
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user