diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index bd581b6a8..a06a25502 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -30,7 +30,8 @@ from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
 from .core import DMatrix, Booster, _expect
 from .training import train as worker_train
 from .tracker import RabitTracker
-from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc
+from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
+from .sklearn import xgboost_model_doc
 
 # Current status is considered as initial support, many features are
 # not properly supported yet.
@@ -639,7 +640,7 @@ class DaskScikitLearnBase(XGBModel):
 
 @xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
                    ['estimators', 'model'])
-class DaskXGBRegressor(DaskScikitLearnBase):
+class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     # pylint: disable=missing-docstring
     def fit(self,
             X,
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index 225232710..f579ee5d7 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -145,6 +145,25 @@ def test_dask_classifier():
     assert prediction.shape[0] == kRows
 
 
+@pytest.mark.skipif(**tm.no_sklearn())
+def test_sklearn_grid_search():
+    from sklearn.model_selection import GridSearchCV
+    with LocalCluster(n_workers=4) as cluster:
+        with Client(cluster) as client:
+            X, y = generate_array()
+            reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
+                                            tree_method='hist')
+            reg.client = client
+            model = GridSearchCV(reg, {'max_depth': [2, 4],
+                                       'n_estimators': [5, 10]},
+                                 cv=2, verbose=1, iid=True)
+            model.fit(X, y)
+            # Expect unique results for each parameter value.  This confirms
+            # sklearn is able to successfully update the parameter.
+            means = model.cv_results_['mean_test_score']
+            assert len(means) == len(set(means))
+
+
 def run_empty_dmatrix(client, parameters):
     def _check_outputs(out, predictions):