[dask] Enable gridsearching with skl. (#5417)
This commit is contained in:
parent
761a5dbdfc
commit
b51124c158
@ -30,7 +30,8 @@ from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
|
|||||||
from .core import DMatrix, Booster, _expect
|
from .core import DMatrix, Booster, _expect
|
||||||
from .training import train as worker_train
|
from .training import train as worker_train
|
||||||
from .tracker import RabitTracker
|
from .tracker import RabitTracker
|
||||||
from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc
|
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
|
||||||
|
from .sklearn import xgboost_model_doc
|
||||||
|
|
||||||
# Current status is considered as initial support, many features are
|
# Current status is considered as initial support, many features are
|
||||||
# not properly supported yet.
|
# not properly supported yet.
|
||||||
@ -639,7 +640,7 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
|
|
||||||
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
|
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
|
||||||
['estimators', 'model'])
|
['estimators', 'model'])
|
||||||
class DaskXGBRegressor(DaskScikitLearnBase):
|
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||||
# pylint: disable=missing-docstring
|
# pylint: disable=missing-docstring
|
||||||
def fit(self,
|
def fit(self,
|
||||||
X,
|
X,
|
||||||
|
|||||||
@ -145,6 +145,25 @@ def test_dask_classifier():
|
|||||||
assert prediction.shape[0] == kRows
|
assert prediction.shape[0] == kRows
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn_grid_search():
    """Check that scikit-learn's ``GridSearchCV`` can drive ``DaskXGBRegressor``.

    A 4-worker ``LocalCluster`` is started, a regressor bound to its client
    is wrapped in a 2-fold grid search over ``max_depth`` and
    ``n_estimators``, and the search is fitted on the arrays returned by
    ``generate_array()``.  The test passes when every parameter combination
    produces a distinct mean test score.
    """
    # Imported lazily so that collecting this module does not require
    # scikit-learn; the skipif marker above guards the actual run.
    from sklearn.model_selection import GridSearchCV

    with LocalCluster(n_workers=4) as cluster:
        with Client(cluster) as client:
            X, y = generate_array()
            reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
                                            tree_method='hist')
            reg.client = client
            # ``iid`` is intentionally not passed: scikit-learn deprecated it
            # in 0.22 and removed it in 0.24, where passing it raises
            # TypeError.
            model = GridSearchCV(reg, {'max_depth': [2, 4],
                                       'n_estimators': [5, 10]},
                                 cv=2, verbose=1)
            model.fit(X, y)
            # Expect unique results for each parameter value.  This confirms
            # sklearn is able to successfully update the parameter.
            means = model.cv_results_['mean_test_score']
            assert len(means) == len(set(means))
|
||||||
|
|
||||||
|
|
||||||
def run_empty_dmatrix(client, parameters):
|
def run_empty_dmatrix(client, parameters):
|
||||||
|
|
||||||
def _check_outputs(out, predictions):
|
def _check_outputs(out, predictions):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user