[dask] Enable gridsearching with skl. (#5417)
This commit is contained in:
@@ -30,7 +30,8 @@ from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
|
||||
from .core import DMatrix, Booster, _expect
|
||||
from .training import train as worker_train
|
||||
from .tracker import RabitTracker
|
||||
from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc
|
||||
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
|
||||
from .sklearn import xgboost_model_doc
|
||||
|
||||
# Current status is considered as initial support, many features are
|
||||
# not properly supported yet.
|
||||
@@ -639,7 +640,7 @@ class DaskScikitLearnBase(XGBModel):
|
||||
|
||||
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
|
||||
['estimators', 'model'])
|
||||
class DaskXGBRegressor(DaskScikitLearnBase):
|
||||
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
# pylint: disable=missing-docstring
|
||||
def fit(self,
|
||||
X,
|
||||
|
||||
Reference in New Issue
Block a user