diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 796608d1d..e6e392d3f 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -632,7 +632,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): params = self.get_xgb_params() # pylint: disable=attribute-defined-outside-init - self.classes_ = da.unique(y).compute() + if isinstance(y, (da.Array)): + self.classes_ = da.unique(y).compute() + else: + self.classes_ = y.drop_duplicates().compute() self.n_classes_ = len(self.classes_) if self.n_classes_ > 2: @@ -640,7 +643,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): params['num_class'] = self.n_classes_ else: params["objective"] = "binary:logistic" - params.setdefault('num_class', self.n_classes_) evals = _evaluation_matrices(self.client, eval_set, sample_weight_eval_set) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index f2ad44080..0b04e2822 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -109,3 +109,16 @@ def test_classifier(client): assert list(history['validation_0'].keys())[0] == 'merror' assert len(list(history['validation_0'])) == 1 assert len(history['validation_0']['merror']) == 2 + + assert classifier.n_classes_ == 10 + + # Test with dataframe. + X_d = dd.from_dask_array(X) + y_d = dd.from_dask_array(y) + classifier.fit(X_d, y_d) + + assert classifier.n_classes_ == 10 + prediction = classifier.predict(X_d) + + assert prediction.ndim == 1 + assert prediction.shape[0] == kRows