Convert `DaskXGBClassifier.classes_` to an array (#8452)
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -73,6 +73,7 @@ from .core import (
|
||||
_deprecate_positional_args,
|
||||
_expect,
|
||||
)
|
||||
from .data import _is_cudf_ser, _is_cupy_array
|
||||
from .sklearn import (
|
||||
XGBClassifier,
|
||||
XGBClassifierBase,
|
||||
@@ -1894,10 +1895,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBa
|
||||
)
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if isinstance(y, (da.Array)):
|
||||
if isinstance(y, da.Array):
|
||||
self.classes_ = await self.client.compute(da.unique(y))
|
||||
else:
|
||||
self.classes_ = await self.client.compute(y.drop_duplicates())
|
||||
if _is_cudf_ser(self.classes_):
|
||||
self.classes_ = self.classes_.to_cupy()
|
||||
if _is_cupy_array(self.classes_):
|
||||
self.classes_ = self.classes_.get()
|
||||
self.classes_ = numpy.array(self.classes_)
|
||||
self.n_classes_ = len(self.classes_)
|
||||
|
||||
if self.n_classes_ > 2:
|
||||
|
||||
Reference in New Issue
Block a user