[dask] Enable gridsearching with skl. (#5417)
This commit is contained in:
@@ -145,6 +145,25 @@ def test_dask_classifier():
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_sklearn_grid_search():
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
with LocalCluster(n_workers=4) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
|
||||
tree_method='hist')
|
||||
reg.client = client
|
||||
model = GridSearchCV(reg, {'max_depth': [2, 4],
|
||||
'n_estimators': [5, 10]},
|
||||
cv=2, verbose=1, iid=True)
|
||||
model.fit(X, y)
|
||||
# Expect unique results for each parameter value This confirms
|
||||
# sklearn is able to successfully update the parameter
|
||||
means = model.cv_results_['mean_test_score']
|
||||
assert len(means) == len(set(means))
|
||||
|
||||
|
||||
def run_empty_dmatrix(client, parameters):
|
||||
|
||||
def _check_outputs(out, predictions):
|
||||
|
||||
Reference in New Issue
Block a user