From 5d6baed998b3a3287006148670d10ec04565f5f1 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Sun, 14 Oct 2018 12:44:53 +1300 Subject: [PATCH] Allow sklearn grid search over parameters specified as kwargs (#3791) --- python-package/xgboost/sklearn.py | 21 +++++++++++++++++++++ tests/python/test_with_sklearn.py | 21 +++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 047aa15f2..fd6d96d54 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -181,6 +181,27 @@ class XGBModel(XGBModelBase): raise XGBoostError('need to call fit or load_model beforehand') return self._Booster + def set_params(self, **params): + """Set the parameters of this estimator. + Modification of the sklearn method to allow unknown kwargs. This allows using + the full range of xgboost parameters that are not defined as member variables + in sklearn grid search. + Returns + ------- + self + """ + if not params: + # Simple optimization to gain speed (inspect is slow) + return self + + for key, value in params.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + self.kwargs[key] = value + + return self + def get_params(self, deep=False): """Get parameters.""" params = super(XGBModel, self).get_params(deep=deep) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index c5ceb282e..38e19922a 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -396,6 +396,27 @@ def test_kwargs(): assert clf.get_params()['n_estimators'] == 1000 +def test_kwargs_grid_search(): + tm._skip_if_no_sklearn() + from sklearn.model_selection import GridSearchCV + from sklearn import datasets + + params = {'tree_method': 'hist'} + clf = xgb.XGBClassifier(n_estimators=1, learning_rate=1.0, **params) + assert clf.get_params()['tree_method'] == 'hist' + # 'max_leaves' is not a default argument of XGBClassifier + # Check we can still do grid search over this parameter + search_params = {'max_leaves': range(2, 5)} + grid_cv = GridSearchCV(clf, search_params, cv=5) + iris = datasets.load_iris() + grid_cv.fit(iris.data, iris.target) + + # Expect unique results for each parameter value + # This confirms sklearn is able to successfully update the parameter + means = grid_cv.cv_results_['mean_test_score'] + assert len(means) == len(set(means)) + + @raises(TypeError) def test_kwargs_error(): tm._skip_if_no_sklearn()