Support scipy sparse in dask. (#7457)
This commit is contained in:
parent
5262e933f7
commit
b124a27f57
@ -112,12 +112,10 @@ except pkg_resources.DistributionNotFound:
|
||||
|
||||
|
||||
try:
|
||||
import sparse
|
||||
import scipy.sparse as scipy_sparse
|
||||
from scipy.sparse import csr_matrix as scipy_csr
|
||||
SCIPY_INSTALLED = True
|
||||
except ImportError:
|
||||
sparse = False
|
||||
scipy_sparse = False
|
||||
scipy_csr: Any = object
|
||||
SCIPY_INSTALLED = False
|
||||
|
||||
@ -33,7 +33,7 @@ from . import rabit, config
|
||||
from .callback import TrainingCallback
|
||||
|
||||
from .compat import LazyLoader
|
||||
from .compat import sparse, scipy_sparse
|
||||
from .compat import scipy_sparse
|
||||
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
|
||||
from .compat import lazy_isinstance
|
||||
|
||||
@ -186,10 +186,13 @@ def concat(value: Any) -> Any: # pylint: disable=too-many-return-statements
|
||||
'''To be replaced with dask builtin.'''
|
||||
if isinstance(value[0], numpy.ndarray):
|
||||
return numpy.concatenate(value, axis=0)
|
||||
if scipy_sparse and isinstance(value[0], scipy_sparse.csr_matrix):
|
||||
return scipy_sparse.vstack(value, format="csr")
|
||||
if scipy_sparse and isinstance(value[0], scipy_sparse.csc_matrix):
|
||||
return scipy_sparse.vstack(value, format="csc")
|
||||
if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
|
||||
# other sparse format will be converted to CSR.
|
||||
return scipy_sparse.vstack(value, format='csr')
|
||||
if sparse and isinstance(value[0], sparse.SparseArray):
|
||||
return sparse.concatenate(value, axis=0)
|
||||
if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
|
||||
return pandas_concat(value, axis=0)
|
||||
if lazy_isinstance(value[0], 'cudf.core.dataframe', 'DataFrame') or \
|
||||
|
||||
@ -7,7 +7,7 @@ import sys
|
||||
import numpy as np
|
||||
import scipy
|
||||
import json
|
||||
from typing import List, Tuple, Dict, Optional, Type, Any, Callable
|
||||
from typing import List, Tuple, Dict, Optional, Type, Any
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@ -149,6 +149,30 @@ def test_from_dask_array() -> None:
|
||||
assert np.all(single_node_predt == from_arr.compute())
|
||||
|
||||
|
||||
def test_dask_sparse(client: "Client") -> None:
|
||||
X_, y_ = make_classification(n_samples=1000, n_informative=5, n_classes=3)
|
||||
rng = np.random.default_rng(seed=0)
|
||||
idx = rng.integers(low=0, high=X_.shape[0], size=X_.shape[0] // 4)
|
||||
X_[idx, :] = np.nan
|
||||
|
||||
# numpy
|
||||
X, y = da.from_array(X_), da.from_array(y_)
|
||||
clf = xgb.dask.DaskXGBClassifier(tree_method="hist", n_estimators=10)
|
||||
clf.client = client
|
||||
clf.fit(X, y, eval_set=[(X, y)])
|
||||
dense_results = clf.evals_result()
|
||||
|
||||
# scipy sparse
|
||||
X, y = da.from_array(X_).map_blocks(scipy.sparse.csr_matrix), da.from_array(y_)
|
||||
clf = xgb.dask.DaskXGBClassifier(tree_method="hist", n_estimators=10)
|
||||
clf.client = client
|
||||
clf.fit(X, y, eval_set=[(X, y)])
|
||||
sparse_results = clf.evals_result()
|
||||
np.testing.assert_allclose(
|
||||
dense_results["validation_0"]["mlogloss"], sparse_results["validation_0"]["mlogloss"]
|
||||
)
|
||||
|
||||
|
||||
def test_dask_predict_shape_infer(client: "Client") -> None:
|
||||
X, y = make_classification(n_samples=1000, n_informative=5, n_classes=3)
|
||||
X_ = dd.from_array(X, chunksize=100)
|
||||
@ -270,7 +294,8 @@ def run_boost_from_prediction(
|
||||
def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
|
||||
from sklearn.datasets import load_breast_cancer, load_digits
|
||||
X_, y_ = load_breast_cancer(return_X_y=True)
|
||||
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
|
||||
X, y = dd.from_array(X_, chunksize=200), dd.from_array(y_, chunksize=200)
|
||||
|
||||
run_boost_from_prediction(X, y, tree_method, client)
|
||||
|
||||
X_, y_ = load_digits(return_X_y=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user