Support dask dataframe as y for classifier. (#5077)

* Support dask dataframe as y for classifier.

* Lint.
This commit is contained in:
Jiaming Yuan
2019-12-02 11:53:30 +08:00
committed by GitHub
parent b9dbfe0931
commit 761e938dbe
2 changed files with 17 additions and 2 deletions

View File

@@ -632,7 +632,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
params = self.get_xgb_params()
# pylint: disable=attribute-defined-outside-init
self.classes_ = da.unique(y).compute()
if isinstance(y, (da.Array)):
self.classes_ = da.unique(y).compute()
else:
self.classes_ = y.drop_duplicates().compute()
self.n_classes_ = len(self.classes_)
if self.n_classes_ > 2:
@@ -640,7 +643,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
params['num_class'] = self.n_classes_
else:
params["objective"] = "binary:logistic"
params.setdefault('num_class', self.n_classes_)
evals = _evaluation_matrices(self.client,
eval_set, sample_weight_eval_set)