From a90d204942d3e623bd31b757160f5f8b897f03cc Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 23 Sep 2023 03:13:48 +0800 Subject: [PATCH] Use array interface for testing numpy arrays. (#9602) --- python-package/xgboost/core.py | 3 ++- python-package/xgboost/data.py | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 486cee514..f94e60321 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2434,6 +2434,7 @@ class Booster: _is_cudf_df, _is_cupy_array, _is_list, + _is_np_array_like, _is_pandas_df, _is_pandas_series, _is_tuple, @@ -2463,7 +2464,7 @@ class Booster: f"got {data.shape[1]}" ) - if isinstance(data, np.ndarray): + if _is_np_array_like(data): from .data import _ensure_np_dtype data, _ = _ensure_np_dtype(data, data.dtype) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 428e48d10..0022a17d4 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -164,8 +164,8 @@ def _is_scipy_coo(data: DataType) -> bool: return isinstance(data, scipy.sparse.coo_matrix) -def _is_numpy_array(data: DataType) -> bool: - return isinstance(data, (np.ndarray, np.matrix)) +def _is_np_array_like(data: DataType) -> bool: + return hasattr(data, "__array_interface__") def _ensure_np_dtype( @@ -1071,7 +1071,7 @@ def dispatch_data_backend( return _from_scipy_csr( data.tocsr(), missing, threads, feature_names, feature_types ) - if _is_numpy_array(data): + if _is_np_array_like(data): return _from_numpy_array( data, missing, threads, feature_names, feature_types, data_split_mode ) @@ -1214,7 +1214,7 @@ def dispatch_meta_backend( if _is_tuple(data): _meta_from_tuple(data, name, dtype, handle) return - if _is_numpy_array(data): + if _is_np_array_like(data): _meta_from_numpy(data, name, dtype, handle) return if _is_pandas_df(data): @@ -1301,7 +1301,7 @@ def _proxy_transform( return _transform_dlpack(data), None, feature_names, feature_types if _is_list(data) or _is_tuple(data): data = np.array(data) - if _is_numpy_array(data): + if _is_np_array_like(data): data, _ = _ensure_np_dtype(data, data.dtype) return data, None, feature_names, feature_types if _is_scipy_csr(data): @@ -1351,7 +1351,7 @@ def dispatch_proxy_set_data( if not allow_host: raise err - if _is_numpy_array(data): + if _is_np_array_like(data): _check_data_shape(data) proxy._set_data_from_array(data) # pylint: disable=W0212 return