Check cupy lazily. (#7752)

This commit is contained in:
Jiaming Yuan 2022-03-26 06:09:58 +08:00 committed by GitHub
parent af0cf88921
commit b3ba0e8708
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 12 deletions

View File

@ -17,7 +17,7 @@ from inspect import signature, Parameter
import numpy as np
import scipy.sparse
from .compat import STRING_TYPES, DataFrame, py_str, PANDAS_INSTALLED, lazy_isinstance
from .compat import STRING_TYPES, DataFrame, py_str, PANDAS_INSTALLED
from .libpath import find_lib_path
from ._typing import (
CStrPptr,
@ -2080,8 +2080,13 @@ class Booster:
f"got {data.shape[1]}"
)
from .data import _is_pandas_df, _transform_pandas_df, _is_cudf_df
from .data import _array_interface
from .data import (
_is_pandas_df,
_transform_pandas_df,
_is_cudf_df,
_is_cupy_array,
_array_interface,
)
enable_categorical = _has_categorical(self, data)
if _is_pandas_df(data):
data, _, _ = _transform_pandas_df(data, enable_categorical)
@ -2118,9 +2123,7 @@ class Booster:
)
)
return _prediction_output(shape, dims, preds, False)
if lazy_isinstance(data, "cupy.core.core", "ndarray") or lazy_isinstance(
data, "cupy._core.core", "ndarray"
):
if _is_cupy_array(data):
from .data import _transform_cupy_array
data = _transform_cupy_array(data)

View File

@ -688,12 +688,10 @@ def _is_cudf_ser(data):
return isinstance(data, cudf.Series)
def _is_cupy_array(data):
try:
import cupy
except ImportError:
return False
return isinstance(data, cupy.ndarray)
def _is_cupy_array(data: Any) -> bool:
return lazy_isinstance(data, "cupy.core.core", "ndarray") or lazy_isinstance(
data, "cupy._core.core", "ndarray"
)
def _transform_cupy_array(data):