[dask] Enable gridsearching with skl. (#5417)

This commit is contained in:
Jiaming Yuan 2020-03-16 04:51:51 +08:00 committed by GitHub
parent 761a5dbdfc
commit b51124c158
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 2 deletions

View File

@ -30,7 +30,8 @@ from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
from .core import DMatrix, Booster, _expect from .core import DMatrix, Booster, _expect
from .training import train as worker_train from .training import train as worker_train
from .tracker import RabitTracker from .tracker import RabitTracker
from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
from .sklearn import xgboost_model_doc
# Current status is considered as initial support, many features are # Current status is considered as initial support, many features are
# not properly supported yet. # not properly supported yet.
@ -639,7 +640,7 @@ class DaskScikitLearnBase(XGBModel):
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""", @xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
['estimators', 'model']) ['estimators', 'model'])
class DaskXGBRegressor(DaskScikitLearnBase): class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
# pylint: disable=missing-docstring # pylint: disable=missing-docstring
def fit(self, def fit(self,
X, X,

View File

@ -145,6 +145,25 @@ def test_dask_classifier():
assert prediction.shape[0] == kRows assert prediction.shape[0] == kRows
@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn_grid_search():
    """Check that scikit-learn's ``GridSearchCV`` can drive ``DaskXGBRegressor``.

    A 2x2 parameter grid (``max_depth`` x ``n_estimators``) is searched with
    2-fold cross validation on a 4-worker ``LocalCluster``.  The test passes
    when every grid point yields a distinct mean test score, which shows that
    the parameters set by the grid search actually reach the estimator.
    """
    # Imported lazily so that collecting this module does not require
    # scikit-learn; the ``skipif`` marker above guards execution.
    from sklearn.model_selection import GridSearchCV

    param_grid = {'max_depth': [2, 4], 'n_estimators': [5, 10]}

    with LocalCluster(n_workers=4) as cluster:
        with Client(cluster) as client:
            X, y = generate_array()
            reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
                                            tree_method='hist')
            # Attach the client explicitly so the estimator uses this
            # test's cluster.
            reg.client = client
            # ``iid`` is deliberately not passed: scikit-learn deprecated it
            # in 0.22 and removed it in 0.24, where supplying it raises
            # ``TypeError``.  The assertion below does not depend on it.
            model = GridSearchCV(reg, param_grid, cv=2, verbose=1)
            model.fit(X, y)
            # Expect unique results for each parameter value.  This confirms
            # sklearn is able to successfully update the parameter.
            means = model.cv_results_['mean_test_score']
            assert len(means) == len(set(means))
def run_empty_dmatrix(client, parameters): def run_empty_dmatrix(client, parameters):
def _check_outputs(out, predictions): def _check_outputs(out, predictions):