Convert `DaskXGBClassifier.classes_` to an array (#8452)

---------

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Scott Gustafson
2023-04-26 14:23:35 -04:00
committed by GitHub
parent 0e7377ba9c
commit 353ed5339d
2 changed files with 26 additions and 1 deletions

View File

@@ -73,6 +73,7 @@ from .core import (
_deprecate_positional_args,
_expect,
)
from .data import _is_cudf_ser, _is_cupy_array
from .sklearn import (
XGBClassifier,
XGBClassifierBase,
@@ -1894,10 +1895,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBa
)
# pylint: disable=attribute-defined-outside-init
if isinstance(y, (da.Array)):
if isinstance(y, da.Array):
self.classes_ = await self.client.compute(da.unique(y))
else:
self.classes_ = await self.client.compute(y.drop_duplicates())
if _is_cudf_ser(self.classes_):
self.classes_ = self.classes_.to_cupy()
if _is_cupy_array(self.classes_):
self.classes_ = self.classes_.get()
self.classes_ = numpy.array(self.classes_)
self.n_classes_ = len(self.classes_)
if self.n_classes_ > 2: