Allow sklearn grid search over parameters specified as kwargs (#3791)
This commit is contained in:
parent
1db28b8718
commit
5d6baed998
@ -181,6 +181,27 @@ class XGBModel(XGBModelBase):
|
|||||||
raise XGBoostError('need to call fit or load_model beforehand')
|
raise XGBoostError('need to call fit or load_model beforehand')
|
||||||
return self._Booster
|
return self._Booster
|
||||||
|
|
||||||
|
def set_params(self, **params):
|
||||||
|
"""Set the parameters of this estimator.
|
||||||
|
Modification of the sklearn method to allow unknown kwargs. This allows using
|
||||||
|
the full range of xgboost parameters that are not defined as member variables
|
||||||
|
in sklearn grid search.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
self
|
||||||
|
"""
|
||||||
|
if not params:
|
||||||
|
# Simple optimization to gain speed (inspect is slow)
|
||||||
|
return self
|
||||||
|
|
||||||
|
for key, value in params.items():
|
||||||
|
if hasattr(self, key):
|
||||||
|
setattr(self, key, value)
|
||||||
|
else:
|
||||||
|
self.kwargs[key] = value
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
def get_params(self, deep=False):
|
def get_params(self, deep=False):
|
||||||
"""Get parameters."""
|
"""Get parameters."""
|
||||||
params = super(XGBModel, self).get_params(deep=deep)
|
params = super(XGBModel, self).get_params(deep=deep)
|
||||||
|
|||||||
@ -396,6 +396,27 @@ def test_kwargs():
|
|||||||
assert clf.get_params()['n_estimators'] == 1000
|
assert clf.get_params()['n_estimators'] == 1000
|
||||||
|
|
||||||
|
|
||||||
|
def test_kwargs_grid_search():
|
||||||
|
tm._skip_if_no_sklearn()
|
||||||
|
from sklearn.model_selection import GridSearchCV
|
||||||
|
from sklearn import datasets
|
||||||
|
|
||||||
|
params = {'tree_method': 'hist'}
|
||||||
|
clf = xgb.XGBClassifier(n_estimators=1, learning_rate=1.0, **params)
|
||||||
|
assert clf.get_params()['tree_method'] == 'hist'
|
||||||
|
# 'max_leaves' is not a default argument of XGBClassifier
|
||||||
|
# Check we can still do grid search over this parameter
|
||||||
|
search_params = {'max_leaves': range(2, 5)}
|
||||||
|
grid_cv = GridSearchCV(clf, search_params, cv=5)
|
||||||
|
iris = datasets.load_iris()
|
||||||
|
grid_cv.fit(iris.data, iris.target)
|
||||||
|
|
||||||
|
# Expect unique results for each parameter value
|
||||||
|
# This confirms sklearn is able to successfully update the parameter
|
||||||
|
means = grid_cv.cv_results_['mean_test_score']
|
||||||
|
assert len(means) == len(set(means))
|
||||||
|
|
||||||
|
|
||||||
@raises(TypeError)
|
@raises(TypeError)
|
||||||
def test_kwargs_error():
|
def test_kwargs_error():
|
||||||
tm._skip_if_no_sklearn()
|
tm._skip_if_no_sklearn()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user