Support more scipy types. (#9881)
This commit is contained in:
@@ -57,12 +57,23 @@ def _check_data_shape(data: DataType) -> None:
|
||||
raise ValueError("Please reshape the input data into 2-dimensional matrix.")
|
||||
|
||||
|
||||
def _is_scipy_csr(data: DataType) -> bool:
|
||||
def is_scipy_csr(data: DataType) -> bool:
|
||||
"""Predicate for scipy CSR input."""
|
||||
is_array = False
|
||||
is_matrix = False
|
||||
try:
|
||||
import scipy.sparse
|
||||
from scipy.sparse import csr_array
|
||||
|
||||
is_array = isinstance(data, csr_array)
|
||||
except ImportError:
|
||||
return False
|
||||
return isinstance(data, scipy.sparse.csr_matrix)
|
||||
pass
|
||||
try:
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
is_matrix = isinstance(data, csr_matrix)
|
||||
except ImportError:
|
||||
pass
|
||||
return is_array or is_matrix
|
||||
|
||||
|
||||
def _array_interface_dict(data: np.ndarray) -> dict:
|
||||
@@ -135,12 +146,23 @@ def _from_scipy_csr(
|
||||
return handle, feature_names, feature_types
|
||||
|
||||
|
||||
def _is_scipy_csc(data: DataType) -> bool:
|
||||
def is_scipy_csc(data: DataType) -> bool:
|
||||
"""Predicate for scipy CSC input."""
|
||||
is_array = False
|
||||
is_matrix = False
|
||||
try:
|
||||
import scipy.sparse
|
||||
from scipy.sparse import csc_array
|
||||
|
||||
is_array = isinstance(data, csc_array)
|
||||
except ImportError:
|
||||
return False
|
||||
return isinstance(data, scipy.sparse.csc_matrix)
|
||||
pass
|
||||
try:
|
||||
from scipy.sparse import csc_matrix
|
||||
|
||||
is_matrix = isinstance(data, csc_matrix)
|
||||
except ImportError:
|
||||
pass
|
||||
return is_array or is_matrix
|
||||
|
||||
|
||||
def _from_scipy_csc(
|
||||
@@ -171,12 +193,23 @@ def _from_scipy_csc(
|
||||
return handle, feature_names, feature_types
|
||||
|
||||
|
||||
def _is_scipy_coo(data: DataType) -> bool:
|
||||
def is_scipy_coo(data: DataType) -> bool:
|
||||
"""Predicate for scipy COO input."""
|
||||
is_array = False
|
||||
is_matrix = False
|
||||
try:
|
||||
import scipy.sparse
|
||||
from scipy.sparse import coo_array
|
||||
|
||||
is_array = isinstance(data, coo_array)
|
||||
except ImportError:
|
||||
return False
|
||||
return isinstance(data, scipy.sparse.coo_matrix)
|
||||
pass
|
||||
try:
|
||||
from scipy.sparse import coo_matrix
|
||||
|
||||
is_matrix = isinstance(data, coo_matrix)
|
||||
except ImportError:
|
||||
pass
|
||||
return is_array or is_matrix
|
||||
|
||||
|
||||
def _is_np_array_like(data: DataType) -> bool:
|
||||
@@ -1138,15 +1171,15 @@ def dispatch_data_backend(
|
||||
"""Dispatch data for DMatrix."""
|
||||
if not _is_cudf_ser(data) and not _is_pandas_series(data):
|
||||
_check_data_shape(data)
|
||||
if _is_scipy_csr(data):
|
||||
if is_scipy_csr(data):
|
||||
return _from_scipy_csr(
|
||||
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
if _is_scipy_csc(data):
|
||||
if is_scipy_csc(data):
|
||||
return _from_scipy_csc(
|
||||
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
if _is_scipy_coo(data):
|
||||
if is_scipy_coo(data):
|
||||
return _from_scipy_csr(
|
||||
data.tocsr(),
|
||||
missing,
|
||||
@@ -1396,9 +1429,15 @@ def _proxy_transform(
|
||||
if _is_np_array_like(data):
|
||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||
return data, None, feature_names, feature_types
|
||||
if _is_scipy_csr(data):
|
||||
if is_scipy_csr(data):
|
||||
data = transform_scipy_sparse(data, True)
|
||||
return data, None, feature_names, feature_types
|
||||
if is_scipy_csc(data):
|
||||
data = transform_scipy_sparse(data.tocsr(), True)
|
||||
return data, None, feature_names, feature_types
|
||||
if is_scipy_coo(data):
|
||||
data = transform_scipy_sparse(data.tocsr(), True)
|
||||
return data, None, feature_names, feature_types
|
||||
if _is_pandas_series(data):
|
||||
import pandas as pd
|
||||
|
||||
@@ -1451,7 +1490,7 @@ def dispatch_proxy_set_data(
|
||||
_check_data_shape(data)
|
||||
proxy._set_data_from_array(data) # pylint: disable=W0212
|
||||
return
|
||||
if _is_scipy_csr(data):
|
||||
if is_scipy_csr(data):
|
||||
proxy._set_data_from_csr(data) # pylint: disable=W0212
|
||||
return
|
||||
raise err
|
||||
|
||||
Reference in New Issue
Block a user