Remove cupy.array_equal, since it's not compatible with cuPy 7.8 (#6528)
This commit is contained in:
parent
ca3da55de4
commit
380f6f4ab8
@ -852,14 +852,18 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
self.classes_ = cp.unique(y.values)
|
||||
self.n_classes_ = len(self.classes_)
|
||||
can_use_label_encoder = False
|
||||
if not cp.array_equal(self.classes_, cp.arange(self.n_classes_)):
|
||||
expected_classes = cp.arange(self.n_classes_)
|
||||
if (self.classes_.shape != expected_classes.shape or
|
||||
not (self.classes_ == expected_classes).all()):
|
||||
raise ValueError(label_encoding_check_error)
|
||||
elif _is_cupy_array(y):
|
||||
import cupy as cp # pylint: disable=E0401
|
||||
self.classes_ = cp.unique(y)
|
||||
self.n_classes_ = len(self.classes_)
|
||||
can_use_label_encoder = False
|
||||
if not cp.array_equal(self.classes_, cp.arange(self.n_classes_)):
|
||||
expected_classes = cp.arange(self.n_classes_)
|
||||
if (self.classes_.shape != expected_classes.shape or
|
||||
not (self.classes_ == expected_classes).all()):
|
||||
raise ValueError(label_encoding_check_error)
|
||||
else:
|
||||
self.classes_ = np.unique(y)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user