Allow sklearn grid search over parameters specified as kwargs (#3791)

This commit is contained in:
Rory Mitchell 2018-10-14 12:44:53 +13:00 committed by GitHub
parent 1db28b8718
commit 5d6baed998
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 0 deletions

View File

@ -181,6 +181,27 @@ class XGBModel(XGBModelBase):
raise XGBoostError('need to call fit or load_model beforehand') raise XGBoostError('need to call fit or load_model beforehand')
return self._Booster return self._Booster
def set_params(self, **params):
"""Set the parameters of this estimator.
Modification of the sklearn method to allow unknown kwargs. This allows using
the full range of xgboost parameters that are not defined as member variables
in sklearn grid search.
Returns
-------
self
"""
if not params:
# Simple optimization to gain speed (inspect is slow)
return self
for key, value in params.items():
if hasattr(self, key):
setattr(self, key, value)
else:
self.kwargs[key] = value
return self
def get_params(self, deep=False): def get_params(self, deep=False):
"""Get parameters.""" """Get parameters."""
params = super(XGBModel, self).get_params(deep=deep) params = super(XGBModel, self).get_params(deep=deep)

View File

@ -396,6 +396,27 @@ def test_kwargs():
assert clf.get_params()['n_estimators'] == 1000 assert clf.get_params()['n_estimators'] == 1000
def test_kwargs_grid_search():
tm._skip_if_no_sklearn()
from sklearn.model_selection import GridSearchCV
from sklearn import datasets
params = {'tree_method': 'hist'}
clf = xgb.XGBClassifier(n_estimators=1, learning_rate=1.0, **params)
assert clf.get_params()['tree_method'] == 'hist'
# 'max_leaves' is not a default argument of XGBClassifier
# Check we can still do grid search over this parameter
search_params = {'max_leaves': range(2, 5)}
grid_cv = GridSearchCV(clf, search_params, cv=5)
iris = datasets.load_iris()
grid_cv.fit(iris.data, iris.target)
# Expect unique results for each parameter value
# This confirms sklearn is able to successfully update the parameter
means = grid_cv.cv_results_['mean_test_score']
assert len(means) == len(set(means))
@raises(TypeError) @raises(TypeError)
def test_kwargs_error(): def test_kwargs_error():
tm._skip_if_no_sklearn() tm._skip_if_no_sklearn()