Use array interface for testing numpy arrays. (#9602)
This commit is contained in:
parent
bbf5b9ee57
commit
a90d204942
@ -2434,6 +2434,7 @@ class Booster:
|
|||||||
_is_cudf_df,
|
_is_cudf_df,
|
||||||
_is_cupy_array,
|
_is_cupy_array,
|
||||||
_is_list,
|
_is_list,
|
||||||
|
_is_np_array_like,
|
||||||
_is_pandas_df,
|
_is_pandas_df,
|
||||||
_is_pandas_series,
|
_is_pandas_series,
|
||||||
_is_tuple,
|
_is_tuple,
|
||||||
@ -2463,7 +2464,7 @@ class Booster:
|
|||||||
f"got {data.shape[1]}"
|
f"got {data.shape[1]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(data, np.ndarray):
|
if _is_np_array_like(data):
|
||||||
from .data import _ensure_np_dtype
|
from .data import _ensure_np_dtype
|
||||||
|
|
||||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||||
|
|||||||
@ -164,8 +164,8 @@ def _is_scipy_coo(data: DataType) -> bool:
|
|||||||
return isinstance(data, scipy.sparse.coo_matrix)
|
return isinstance(data, scipy.sparse.coo_matrix)
|
||||||
|
|
||||||
|
|
||||||
def _is_numpy_array(data: DataType) -> bool:
|
def _is_np_array_like(data: DataType) -> bool:
|
||||||
return isinstance(data, (np.ndarray, np.matrix))
|
return hasattr(data, "__array_interface__")
|
||||||
|
|
||||||
|
|
||||||
def _ensure_np_dtype(
|
def _ensure_np_dtype(
|
||||||
@ -1071,7 +1071,7 @@ def dispatch_data_backend(
|
|||||||
return _from_scipy_csr(
|
return _from_scipy_csr(
|
||||||
data.tocsr(), missing, threads, feature_names, feature_types
|
data.tocsr(), missing, threads, feature_names, feature_types
|
||||||
)
|
)
|
||||||
if _is_numpy_array(data):
|
if _is_np_array_like(data):
|
||||||
return _from_numpy_array(
|
return _from_numpy_array(
|
||||||
data, missing, threads, feature_names, feature_types, data_split_mode
|
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||||
)
|
)
|
||||||
@ -1214,7 +1214,7 @@ def dispatch_meta_backend(
|
|||||||
if _is_tuple(data):
|
if _is_tuple(data):
|
||||||
_meta_from_tuple(data, name, dtype, handle)
|
_meta_from_tuple(data, name, dtype, handle)
|
||||||
return
|
return
|
||||||
if _is_numpy_array(data):
|
if _is_np_array_like(data):
|
||||||
_meta_from_numpy(data, name, dtype, handle)
|
_meta_from_numpy(data, name, dtype, handle)
|
||||||
return
|
return
|
||||||
if _is_pandas_df(data):
|
if _is_pandas_df(data):
|
||||||
@ -1301,7 +1301,7 @@ def _proxy_transform(
|
|||||||
return _transform_dlpack(data), None, feature_names, feature_types
|
return _transform_dlpack(data), None, feature_names, feature_types
|
||||||
if _is_list(data) or _is_tuple(data):
|
if _is_list(data) or _is_tuple(data):
|
||||||
data = np.array(data)
|
data = np.array(data)
|
||||||
if _is_numpy_array(data):
|
if _is_np_array_like(data):
|
||||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||||
return data, None, feature_names, feature_types
|
return data, None, feature_names, feature_types
|
||||||
if _is_scipy_csr(data):
|
if _is_scipy_csr(data):
|
||||||
@ -1351,7 +1351,7 @@ def dispatch_proxy_set_data(
|
|||||||
if not allow_host:
|
if not allow_host:
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
if _is_numpy_array(data):
|
if _is_np_array_like(data):
|
||||||
_check_data_shape(data)
|
_check_data_shape(data)
|
||||||
proxy._set_data_from_array(data) # pylint: disable=W0212
|
proxy._set_data_from_array(data) # pylint: disable=W0212
|
||||||
return
|
return
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user