Check cupy lazily. (#7752)
This commit is contained in:
parent
af0cf88921
commit
b3ba0e8708
@ -17,7 +17,7 @@ from inspect import signature, Parameter
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
|
||||
from .compat import STRING_TYPES, DataFrame, py_str, PANDAS_INSTALLED, lazy_isinstance
|
||||
from .compat import STRING_TYPES, DataFrame, py_str, PANDAS_INSTALLED
|
||||
from .libpath import find_lib_path
|
||||
from ._typing import (
|
||||
CStrPptr,
|
||||
@ -2080,8 +2080,13 @@ class Booster:
|
||||
f"got {data.shape[1]}"
|
||||
)
|
||||
|
||||
from .data import _is_pandas_df, _transform_pandas_df, _is_cudf_df
|
||||
from .data import _array_interface
|
||||
from .data import (
|
||||
_is_pandas_df,
|
||||
_transform_pandas_df,
|
||||
_is_cudf_df,
|
||||
_is_cupy_array,
|
||||
_array_interface,
|
||||
)
|
||||
enable_categorical = _has_categorical(self, data)
|
||||
if _is_pandas_df(data):
|
||||
data, _, _ = _transform_pandas_df(data, enable_categorical)
|
||||
@ -2118,9 +2123,7 @@ class Booster:
|
||||
)
|
||||
)
|
||||
return _prediction_output(shape, dims, preds, False)
|
||||
if lazy_isinstance(data, "cupy.core.core", "ndarray") or lazy_isinstance(
|
||||
data, "cupy._core.core", "ndarray"
|
||||
):
|
||||
if _is_cupy_array(data):
|
||||
from .data import _transform_cupy_array
|
||||
|
||||
data = _transform_cupy_array(data)
|
||||
|
||||
@ -688,12 +688,10 @@ def _is_cudf_ser(data):
|
||||
return isinstance(data, cudf.Series)
|
||||
|
||||
|
||||
def _is_cupy_array(data):
|
||||
try:
|
||||
import cupy
|
||||
except ImportError:
|
||||
return False
|
||||
return isinstance(data, cupy.ndarray)
|
||||
def _is_cupy_array(data: Any) -> bool:
|
||||
return lazy_isinstance(data, "cupy.core.core", "ndarray") or lazy_isinstance(
|
||||
data, "cupy._core.core", "ndarray"
|
||||
)
|
||||
|
||||
|
||||
def _transform_cupy_array(data):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user