Fix filtering callable objects in skl xgb param. (#6466)
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
05e5563c2c
commit
d6386e45e8
@ -398,7 +398,7 @@ class XGBModel(XGBModelBase):
|
|||||||
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder'}
|
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder'}
|
||||||
filtered = dict()
|
filtered = dict()
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
if k not in wrapper_specific:
|
if k not in wrapper_specific and not callable(v):
|
||||||
filtered[k] = v
|
filtered[k] = v
|
||||||
return filtered
|
return filtered
|
||||||
|
|
||||||
|
|||||||
@ -399,6 +399,21 @@ def test_classification_with_custom_objective():
|
|||||||
X, y
|
X, y
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cls = xgb.XGBClassifier(use_label_encoder=False, n_estimators=1)
|
||||||
|
cls.fit(X, y)
|
||||||
|
|
||||||
|
is_called = [False]
|
||||||
|
|
||||||
|
def wrapped(y, p):
|
||||||
|
is_called[0] = True
|
||||||
|
return logregobj(y, p)
|
||||||
|
|
||||||
|
cls.set_params(objective=wrapped)
|
||||||
|
cls.predict(X) # no throw
|
||||||
|
cls.fit(X, y)
|
||||||
|
|
||||||
|
assert is_called[0]
|
||||||
|
|
||||||
|
|
||||||
def test_sklearn_api():
|
def test_sklearn_api():
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user