Support dask dataframe as y for classifier. (#5077)
* Support dask dataframe as y for classifier.
* Lint.
This commit is contained in:
parent
b9dbfe0931
commit
761e938dbe
@ -632,7 +632,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
params = self.get_xgb_params()
|
params = self.get_xgb_params()
|
||||||
|
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
self.classes_ = da.unique(y).compute()
|
if isinstance(y, (da.Array)):
|
||||||
|
self.classes_ = da.unique(y).compute()
|
||||||
|
else:
|
||||||
|
self.classes_ = y.drop_duplicates().compute()
|
||||||
self.n_classes_ = len(self.classes_)
|
self.n_classes_ = len(self.classes_)
|
||||||
|
|
||||||
if self.n_classes_ > 2:
|
if self.n_classes_ > 2:
|
||||||
@ -640,7 +643,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
params['num_class'] = self.n_classes_
|
params['num_class'] = self.n_classes_
|
||||||
else:
|
else:
|
||||||
params["objective"] = "binary:logistic"
|
params["objective"] = "binary:logistic"
|
||||||
params.setdefault('num_class', self.n_classes_)
|
|
||||||
|
|
||||||
evals = _evaluation_matrices(self.client,
|
evals = _evaluation_matrices(self.client,
|
||||||
eval_set, sample_weight_eval_set)
|
eval_set, sample_weight_eval_set)
|
||||||
|
|||||||
@ -109,3 +109,16 @@ def test_classifier(client):
|
|||||||
assert list(history['validation_0'].keys())[0] == 'merror'
|
assert list(history['validation_0'].keys())[0] == 'merror'
|
||||||
assert len(list(history['validation_0'])) == 1
|
assert len(list(history['validation_0'])) == 1
|
||||||
assert len(history['validation_0']['merror']) == 2
|
assert len(history['validation_0']['merror']) == 2
|
||||||
|
|
||||||
|
assert classifier.n_classes_ == 10
|
||||||
|
|
||||||
|
# Test with dataframe.
|
||||||
|
X_d = dd.from_dask_array(X)
|
||||||
|
y_d = dd.from_dask_array(y)
|
||||||
|
classifier.fit(X_d, y_d)
|
||||||
|
|
||||||
|
assert classifier.n_classes_ == 10
|
||||||
|
prediction = classifier.predict(X_d)
|
||||||
|
|
||||||
|
assert prediction.ndim == 1
|
||||||
|
assert prediction.shape[0] == kRows
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user