Support scipy sparse in dask. (#7457)

This commit is contained in:
Jiaming Yuan 2021-11-23 16:45:36 +08:00 committed by GitHub
parent 5262e933f7
commit b124a27f57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 7 deletions

View File

@ -112,12 +112,10 @@ except pkg_resources.DistributionNotFound:
try:
import sparse
import scipy.sparse as scipy_sparse
from scipy.sparse import csr_matrix as scipy_csr
SCIPY_INSTALLED = True
except ImportError:
sparse = False
scipy_sparse = False
scipy_csr: Any = object
SCIPY_INSTALLED = False

View File

@ -33,7 +33,7 @@ from . import rabit, config
from .callback import TrainingCallback
from .compat import LazyLoader
from .compat import sparse, scipy_sparse
from .compat import scipy_sparse
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
from .compat import lazy_isinstance
@ -186,10 +186,13 @@ def concat(value: Any) -> Any: # pylint: disable=too-many-return-statements
'''To be replaced with dask builtin.'''
if isinstance(value[0], numpy.ndarray):
return numpy.concatenate(value, axis=0)
if scipy_sparse and isinstance(value[0], scipy_sparse.csr_matrix):
return scipy_sparse.vstack(value, format="csr")
if scipy_sparse and isinstance(value[0], scipy_sparse.csc_matrix):
return scipy_sparse.vstack(value, format="csc")
if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
# other sparse format will be converted to CSR.
return scipy_sparse.vstack(value, format='csr')
if sparse and isinstance(value[0], sparse.SparseArray):
return sparse.concatenate(value, axis=0)
if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
return pandas_concat(value, axis=0)
if lazy_isinstance(value[0], 'cudf.core.dataframe', 'DataFrame') or \

View File

@ -7,7 +7,7 @@ import sys
import numpy as np
import scipy
import json
from typing import List, Tuple, Dict, Optional, Type, Any, Callable
from typing import List, Tuple, Dict, Optional, Type, Any
import asyncio
from functools import partial
from concurrent.futures import ThreadPoolExecutor
@ -149,6 +149,30 @@ def test_from_dask_array() -> None:
assert np.all(single_node_predt == from_arr.compute())
def test_dask_sparse(client: "Client") -> None:
    """Training on scipy CSR blocks must reproduce the dense-array metrics.

    Randomly chosen rows (indices drawn by ``rng.integers``, so repeats are
    possible) are overwritten with NaN.  The same data is then fitted twice:
    once as plain numpy-backed dask blocks and once with every block mapped to
    ``scipy.sparse.csr_matrix``.  The recorded ``mlogloss`` histories of both
    runs are compared with ``np.testing.assert_allclose``.
    """
    features, labels = make_classification(
        n_samples=1000, n_informative=5, n_classes=3
    )
    n_rows = features.shape[0]
    generator = np.random.default_rng(seed=0)
    nan_rows = generator.integers(low=0, high=n_rows, size=n_rows // 4)
    features[nan_rows, :] = np.nan

    def fit_and_collect(data: Any, target: Any) -> Any:
        # Train a fresh classifier, evaluating on the training data itself,
        # and hand back the recorded evaluation history.
        model = xgb.dask.DaskXGBClassifier(tree_method="hist", n_estimators=10)
        model.client = client
        model.fit(data, target, eval_set=[(data, target)])
        return model.evals_result()

    # numpy
    dense_results = fit_and_collect(
        da.from_array(features), da.from_array(labels)
    )

    # scipy sparse
    sparse_results = fit_and_collect(
        da.from_array(features).map_blocks(scipy.sparse.csr_matrix),
        da.from_array(labels),
    )

    dense_loss = dense_results["validation_0"]["mlogloss"]
    sparse_loss = sparse_results["validation_0"]["mlogloss"]
    np.testing.assert_allclose(dense_loss, sparse_loss)
def test_dask_predict_shape_infer(client: "Client") -> None:
X, y = make_classification(n_samples=1000, n_informative=5, n_classes=3)
X_ = dd.from_array(X, chunksize=100)
@ -270,7 +294,8 @@ def run_boost_from_prediction(
def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
from sklearn.datasets import load_breast_cancer, load_digits
X_, y_ = load_breast_cancer(return_X_y=True)
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
X, y = dd.from_array(X_, chunksize=200), dd.from_array(y_, chunksize=200)
run_boost_from_prediction(X, y, tree_method, client)
X_, y_ = load_digits(return_X_y=True)