diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 88bd1c819..35c5c009f 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -73,6 +73,7 @@ from .core import ( _deprecate_positional_args, _expect, ) +from .data import _is_cudf_ser, _is_cupy_array from .sklearn import ( XGBClassifier, XGBClassifierBase, @@ -1894,10 +1895,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBa ) # pylint: disable=attribute-defined-outside-init - if isinstance(y, (da.Array)): + if isinstance(y, da.Array): self.classes_ = await self.client.compute(da.unique(y)) else: self.classes_ = await self.client.compute(y.drop_duplicates()) + if _is_cudf_ser(self.classes_): + self.classes_ = self.classes_.to_cupy() + if _is_cupy_array(self.classes_): + self.classes_ = self.classes_.get() + self.classes_ = numpy.array(self.classes_) self.n_classes_ = len(self.classes_) if self.n_classes_ > 2: diff --git a/tests/test_distributed/test_with_dask/test_with_dask.py b/tests/test_distributed/test_with_dask/test_with_dask.py index 0bf952025..5e9303a46 100644 --- a/tests/test_distributed/test_with_dask/test_with_dask.py +++ b/tests/test_distributed/test_with_dask/test_with_dask.py @@ -192,6 +192,25 @@ def deterministic_repartition( return X, y, m +@pytest.mark.parametrize("to_frame", [True, False]) +def test_xgbclassifier_classes_type_and_value(to_frame: bool, client: "Client"): + X, y = make_classification(n_samples=1000, n_features=4, random_state=123) + if to_frame: + import pandas as pd + feats = [f"var_{i}" for i in range(4)] + df = pd.DataFrame(X, columns=feats) + df["target"] = y + df = dd.from_pandas(df, npartitions=1) + X, y = df[feats], df["target"] + else: + X = da.from_array(X) + y = da.from_array(y) + + est = xgb.dask.DaskXGBClassifier(n_estimators=10).fit(X, y) + assert isinstance(est.classes_, np.ndarray) + np.testing.assert_array_equal(est.classes_, np.array([0, 1])) + + def test_from_dask_dataframe() -> None: with LocalCluster(n_workers=kWorkers, dashboard_address=":0") as cluster: with Client(cluster) as client: