Support scipy sparse in dask. (#7457)
This commit is contained in:
@@ -112,12 +112,10 @@ except pkg_resources.DistributionNotFound:
|
||||
|
||||
|
||||
try:
|
||||
import sparse
|
||||
import scipy.sparse as scipy_sparse
|
||||
from scipy.sparse import csr_matrix as scipy_csr
|
||||
SCIPY_INSTALLED = True
|
||||
except ImportError:
|
||||
sparse = False
|
||||
scipy_sparse = False
|
||||
scipy_csr: Any = object
|
||||
SCIPY_INSTALLED = False
|
||||
|
||||
@@ -33,7 +33,7 @@ from . import rabit, config
|
||||
from .callback import TrainingCallback
|
||||
|
||||
from .compat import LazyLoader
|
||||
from .compat import sparse, scipy_sparse
|
||||
from .compat import scipy_sparse
|
||||
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
|
||||
from .compat import lazy_isinstance
|
||||
|
||||
@@ -186,10 +186,13 @@ def concat(value: Any) -> Any: # pylint: disable=too-many-return-statements
|
||||
'''To be replaced with dask builtin.'''
|
||||
if isinstance(value[0], numpy.ndarray):
|
||||
return numpy.concatenate(value, axis=0)
|
||||
if scipy_sparse and isinstance(value[0], scipy_sparse.csr_matrix):
|
||||
return scipy_sparse.vstack(value, format="csr")
|
||||
if scipy_sparse and isinstance(value[0], scipy_sparse.csc_matrix):
|
||||
return scipy_sparse.vstack(value, format="csc")
|
||||
if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
|
||||
# other sparse format will be converted to CSR.
|
||||
return scipy_sparse.vstack(value, format='csr')
|
||||
if sparse and isinstance(value[0], sparse.SparseArray):
|
||||
return sparse.concatenate(value, axis=0)
|
||||
if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
|
||||
return pandas_concat(value, axis=0)
|
||||
if lazy_isinstance(value[0], 'cudf.core.dataframe', 'DataFrame') or \
|
||||
|
||||
Reference in New Issue
Block a user