Support dask dataframe as y for classifier. (#5077)

* Support dask dataframe as y for classifier.

* Lint.
This commit is contained in:
Jiaming Yuan 2019-12-02 11:53:30 +08:00 committed by GitHub
parent b9dbfe0931
commit 761e938dbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 2 deletions

View File

@@ -632,7 +632,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         params = self.get_xgb_params()
         # pylint: disable=attribute-defined-outside-init
-        self.classes_ = da.unique(y).compute()
+        if isinstance(y, (da.Array)):
+            self.classes_ = da.unique(y).compute()
+        else:
+            self.classes_ = y.drop_duplicates().compute()
         self.n_classes_ = len(self.classes_)
         if self.n_classes_ > 2:
@@ -640,7 +643,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             params['num_class'] = self.n_classes_
         else:
             params["objective"] = "binary:logistic"
-            params.setdefault('num_class', self.n_classes_)
         evals = _evaluation_matrices(self.client,
                                      eval_set, sample_weight_eval_set)

View File

@@ -109,3 +109,16 @@ def test_classifier(client):
     assert list(history['validation_0'].keys())[0] == 'merror'
     assert len(list(history['validation_0'])) == 1
     assert len(history['validation_0']['merror']) == 2
+    assert classifier.n_classes_ == 10
+    # Test with dataframe.
+    X_d = dd.from_dask_array(X)
+    y_d = dd.from_dask_array(y)
+    classifier.fit(X_d, y_d)
+    assert classifier.n_classes_ == 10
+    prediction = classifier.predict(X_d)
+    assert prediction.ndim == 1
+    assert prediction.shape[0] == kRows