Allow sklearn grid search over parameters specified as kwargs (#3791)

This commit is contained in:
Rory Mitchell
2018-10-14 12:44:53 +13:00
committed by GitHub
parent 1db28b8718
commit 5d6baed998
2 changed files with 42 additions and 0 deletions

View File

@@ -396,6 +396,27 @@ def test_kwargs():
assert clf.get_params()['n_estimators'] == 1000
def test_kwargs_grid_search():
tm._skip_if_no_sklearn()
from sklearn.model_selection import GridSearchCV
from sklearn import datasets
params = {'tree_method': 'hist'}
clf = xgb.XGBClassifier(n_estimators=1, learning_rate=1.0, **params)
assert clf.get_params()['tree_method'] == 'hist'
# 'max_leaves' is not a default argument of XGBClassifier
# Check we can still do grid search over this parameter
search_params = {'max_leaves': range(2, 5)}
grid_cv = GridSearchCV(clf, search_params, cv=5)
iris = datasets.load_iris()
grid_cv.fit(iris.data, iris.target)
# Expect unique results for each parameter value
# This confirms sklearn is able to successfully update the parameter
means = grid_cv.cv_results_['mean_test_score']
assert len(means) == len(set(means))
@raises(TypeError)
def test_kwargs_error():
tm._skip_if_no_sklearn()