From d6386e45e8ed19c238ee06544dabbcbf56e02bbc Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 5 Dec 2020 17:20:36 +0800 Subject: [PATCH] Fix filtering callable objects in skl xgb param. (#6466) Co-authored-by: Hyunsu Cho --- python-package/xgboost/sklearn.py | 2 +- tests/python/test_with_sklearn.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 2703b8160..d3b2a1bf8 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -398,7 +398,7 @@ class XGBModel(XGBModelBase): 'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder'} filtered = dict() for k, v in params.items(): - if k not in wrapper_specific: + if k not in wrapper_specific and not callable(v): filtered[k] = v return filtered diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 318c349f3..8a4f17ffb 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -399,6 +399,21 @@ def test_classification_with_custom_objective(): X, y ) + cls = xgb.XGBClassifier(use_label_encoder=False, n_estimators=1) + cls.fit(X, y) + + is_called = [False] + + def wrapped(y, p): + is_called[0] = True + return logregobj(y, p) + + cls.set_params(objective=wrapped) + cls.predict(X) # no throw + cls.fit(X, y) + + assert is_called[0] + def test_sklearn_api(): from sklearn.datasets import load_iris