Support scipy sparse in dask. (#7457)

This commit is contained in:
Jiaming Yuan
2021-11-23 16:45:36 +08:00
committed by GitHub
parent 5262e933f7
commit b124a27f57
3 changed files with 33 additions and 7 deletions

View File

@@ -112,12 +112,10 @@ except pkg_resources.DistributionNotFound:
# Optional sparse-matrix imports: the names below are bound to the real modules
# when the imports succeed, and to placeholders when an ImportError is raised.
# NOTE(review): this is a diff hunk rendered without +/- markers or indentation,
# so some of these lines may exist on only one side of the change -- confirm
# against the actual file. As rendered, every import shares one `try`, so a
# failure of any single import takes the fallback branch for all names.
try:
import sparse
import scipy.sparse as scipy_sparse
from scipy.sparse import csr_matrix as scipy_csr
SCIPY_INSTALLED = True
except ImportError:
# Falsy module placeholders, so guards of the form `if scipy_sparse and ...`
# short-circuit before touching any attribute of the missing module.
sparse = False
scipy_sparse = False
# `object` stands in for the CSR matrix class when the import fails.
scipy_csr: Any = object
SCIPY_INSTALLED = False

View File

@@ -33,7 +33,7 @@ from . import rabit, config
from .callback import TrainingCallback
from .compat import LazyLoader
from .compat import sparse, scipy_sparse
from .compat import scipy_sparse
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
from .compat import lazy_isinstance
@@ -186,10 +186,13 @@ def concat(value: Any) -> Any: # pylint: disable=too-many-return-statements
'''To be replaced with dask builtin.'''
if isinstance(value[0], numpy.ndarray):
return numpy.concatenate(value, axis=0)
if scipy_sparse and isinstance(value[0], scipy_sparse.csr_matrix):
return scipy_sparse.vstack(value, format="csr")
if scipy_sparse and isinstance(value[0], scipy_sparse.csc_matrix):
return scipy_sparse.vstack(value, format="csc")
if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
# Any other sparse format is converted to CSR.
return scipy_sparse.vstack(value, format='csr')
if sparse and isinstance(value[0], sparse.SparseArray):
return sparse.concatenate(value, axis=0)
if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
return pandas_concat(value, axis=0)
if lazy_isinstance(value[0], 'cudf.core.dataframe', 'DataFrame') or \