Use array interface for testing numpy arrays. (#9602)

This commit is contained in:
Jiaming Yuan 2023-09-23 03:13:48 +08:00 committed by GitHub
parent bbf5b9ee57
commit a90d204942
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 7 deletions

View File

@ -2434,6 +2434,7 @@ class Booster:
_is_cudf_df, _is_cudf_df,
_is_cupy_array, _is_cupy_array,
_is_list, _is_list,
_is_np_array_like,
_is_pandas_df, _is_pandas_df,
_is_pandas_series, _is_pandas_series,
_is_tuple, _is_tuple,
@ -2463,7 +2464,7 @@ class Booster:
f"got {data.shape[1]}" f"got {data.shape[1]}"
) )
if isinstance(data, np.ndarray): if _is_np_array_like(data):
from .data import _ensure_np_dtype from .data import _ensure_np_dtype
data, _ = _ensure_np_dtype(data, data.dtype) data, _ = _ensure_np_dtype(data, data.dtype)

View File

@ -164,8 +164,8 @@ def _is_scipy_coo(data: DataType) -> bool:
return isinstance(data, scipy.sparse.coo_matrix) return isinstance(data, scipy.sparse.coo_matrix)
def _is_numpy_array(data: DataType) -> bool: def _is_np_array_like(data: DataType) -> bool:
return isinstance(data, (np.ndarray, np.matrix)) return hasattr(data, "__array_interface__")
def _ensure_np_dtype( def _ensure_np_dtype(
@ -1071,7 +1071,7 @@ def dispatch_data_backend(
return _from_scipy_csr( return _from_scipy_csr(
data.tocsr(), missing, threads, feature_names, feature_types data.tocsr(), missing, threads, feature_names, feature_types
) )
if _is_numpy_array(data): if _is_np_array_like(data):
return _from_numpy_array( return _from_numpy_array(
data, missing, threads, feature_names, feature_types, data_split_mode data, missing, threads, feature_names, feature_types, data_split_mode
) )
@ -1214,7 +1214,7 @@ def dispatch_meta_backend(
if _is_tuple(data): if _is_tuple(data):
_meta_from_tuple(data, name, dtype, handle) _meta_from_tuple(data, name, dtype, handle)
return return
if _is_numpy_array(data): if _is_np_array_like(data):
_meta_from_numpy(data, name, dtype, handle) _meta_from_numpy(data, name, dtype, handle)
return return
if _is_pandas_df(data): if _is_pandas_df(data):
@ -1301,7 +1301,7 @@ def _proxy_transform(
return _transform_dlpack(data), None, feature_names, feature_types return _transform_dlpack(data), None, feature_names, feature_types
if _is_list(data) or _is_tuple(data): if _is_list(data) or _is_tuple(data):
data = np.array(data) data = np.array(data)
if _is_numpy_array(data): if _is_np_array_like(data):
data, _ = _ensure_np_dtype(data, data.dtype) data, _ = _ensure_np_dtype(data, data.dtype)
return data, None, feature_names, feature_types return data, None, feature_names, feature_types
if _is_scipy_csr(data): if _is_scipy_csr(data):
@ -1351,7 +1351,7 @@ def dispatch_proxy_set_data(
if not allow_host: if not allow_host:
raise err raise err
if _is_numpy_array(data): if _is_np_array_like(data):
_check_data_shape(data) _check_data_shape(data)
proxy._set_data_from_array(data) # pylint: disable=W0212 proxy._set_data_from_array(data) # pylint: disable=W0212
return return