Support dask dataframe as y for classifier. (#5077)
* Support dask dataframe as y for classifier. * Lint.
This commit is contained in:
parent
b9dbfe0931
commit
761e938dbe
@ -632,7 +632,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
params = self.get_xgb_params()
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.classes_ = da.unique(y).compute()
|
||||
if isinstance(y, (da.Array)):
|
||||
self.classes_ = da.unique(y).compute()
|
||||
else:
|
||||
self.classes_ = y.drop_duplicates().compute()
|
||||
self.n_classes_ = len(self.classes_)
|
||||
|
||||
if self.n_classes_ > 2:
|
||||
@ -640,7 +643,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
params['num_class'] = self.n_classes_
|
||||
else:
|
||||
params["objective"] = "binary:logistic"
|
||||
params.setdefault('num_class', self.n_classes_)
|
||||
|
||||
evals = _evaluation_matrices(self.client,
|
||||
eval_set, sample_weight_eval_set)
|
||||
|
||||
@ -109,3 +109,16 @@ def test_classifier(client):
|
||||
assert list(history['validation_0'].keys())[0] == 'merror'
|
||||
assert len(list(history['validation_0'])) == 1
|
||||
assert len(history['validation_0']['merror']) == 2
|
||||
|
||||
assert classifier.n_classes_ == 10
|
||||
|
||||
# Test with dataframe.
|
||||
X_d = dd.from_dask_array(X)
|
||||
y_d = dd.from_dask_array(y)
|
||||
classifier.fit(X_d, y_d)
|
||||
|
||||
assert classifier.n_classes_ == 10
|
||||
prediction = classifier.predict(X_d)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user