Convert `DaskXGBClassifier.classes_` to an array (#8452)
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
parent
0e7377ba9c
commit
353ed5339d
@ -73,6 +73,7 @@ from .core import (
|
|||||||
_deprecate_positional_args,
|
_deprecate_positional_args,
|
||||||
_expect,
|
_expect,
|
||||||
)
|
)
|
||||||
|
from .data import _is_cudf_ser, _is_cupy_array
|
||||||
from .sklearn import (
|
from .sklearn import (
|
||||||
XGBClassifier,
|
XGBClassifier,
|
||||||
XGBClassifierBase,
|
XGBClassifierBase,
|
||||||
@ -1894,10 +1895,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBa
|
|||||||
)
|
)
|
||||||
|
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
if isinstance(y, (da.Array)):
|
if isinstance(y, da.Array):
|
||||||
self.classes_ = await self.client.compute(da.unique(y))
|
self.classes_ = await self.client.compute(da.unique(y))
|
||||||
else:
|
else:
|
||||||
self.classes_ = await self.client.compute(y.drop_duplicates())
|
self.classes_ = await self.client.compute(y.drop_duplicates())
|
||||||
|
if _is_cudf_ser(self.classes_):
|
||||||
|
self.classes_ = self.classes_.to_cupy()
|
||||||
|
if _is_cupy_array(self.classes_):
|
||||||
|
self.classes_ = self.classes_.get()
|
||||||
|
self.classes_ = numpy.array(self.classes_)
|
||||||
self.n_classes_ = len(self.classes_)
|
self.n_classes_ = len(self.classes_)
|
||||||
|
|
||||||
if self.n_classes_ > 2:
|
if self.n_classes_ > 2:
|
||||||
|
|||||||
@ -192,6 +192,25 @@ def deterministic_repartition(
|
|||||||
return X, y, m
|
return X, y, m
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("to_frame", [True, False])
|
||||||
|
def test_xgbclassifier_classes_type_and_value(to_frame: bool, client: "Client"):
|
||||||
|
X, y = make_classification(n_samples=1000, n_features=4, random_state=123)
|
||||||
|
if to_frame:
|
||||||
|
import pandas as pd
|
||||||
|
feats = [f"var_{i}" for i in range(4)]
|
||||||
|
df = pd.DataFrame(X, columns=feats)
|
||||||
|
df["target"] = y
|
||||||
|
df = dd.from_pandas(df, npartitions=1)
|
||||||
|
X, y = df[feats], df["target"]
|
||||||
|
else:
|
||||||
|
X = da.from_array(X)
|
||||||
|
y = da.from_array(y)
|
||||||
|
|
||||||
|
est = xgb.dask.DaskXGBClassifier(n_estimators=10).fit(X, y)
|
||||||
|
assert isinstance(est.classes_, np.ndarray)
|
||||||
|
np.testing.assert_array_equal(est.classes_, np.array([0, 1]))
|
||||||
|
|
||||||
|
|
||||||
def test_from_dask_dataframe() -> None:
|
def test_from_dask_dataframe() -> None:
|
||||||
with LocalCluster(n_workers=kWorkers, dashboard_address=":0") as cluster:
|
with LocalCluster(n_workers=kWorkers, dashboard_address=":0") as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user