Support scipy sparse in dask. (#7457)

This commit is contained in:
Jiaming Yuan
2021-11-23 16:45:36 +08:00
committed by GitHub
parent 5262e933f7
commit b124a27f57
3 changed files with 33 additions and 7 deletions

View File

@@ -7,7 +7,7 @@ import sys
import numpy as np
import scipy
import json
from typing import List, Tuple, Dict, Optional, Type, Any, Callable
from typing import List, Tuple, Dict, Optional, Type, Any
import asyncio
from functools import partial
from concurrent.futures import ThreadPoolExecutor
@@ -149,6 +149,30 @@ def test_from_dask_array() -> None:
assert np.all(single_node_predt == from_arr.compute())
def test_dask_sparse(client: "Client") -> None:
    """Check that fitting on scipy CSR blocks matches fitting on dense blocks.

    The same ``DaskXGBClassifier`` configuration is fitted twice on identical
    data: once with a dense dask array and once with that array's blocks
    converted to ``scipy.sparse.csr_matrix``.  The per-iteration ``mlogloss``
    recorded on the evaluation set must agree between the two runs.

    Parameters
    ----------
    client :
        Dask client assigned to each classifier before fitting.
    """
    # A bare ``import scipy`` does not guarantee that the ``sparse`` submodule
    # is loaded; import it explicitly instead of relying on another package
    # having imported it as a side effect.
    from scipy import sparse

    X_, y_ = make_classification(n_samples=1000, n_informative=5, n_classes=3)
    rng = np.random.default_rng(seed=0)
    idx = rng.integers(low=0, high=X_.shape[0], size=X_.shape[0] // 4)
    # Overwrite the randomly selected rows with NaN so both runs see NaN entries.
    X_[idx, :] = np.nan

    # numpy
    X, y = da.from_array(X_), da.from_array(y_)
    clf = xgb.dask.DaskXGBClassifier(tree_method="hist", n_estimators=10)
    clf.client = client
    clf.fit(X, y, eval_set=[(X, y)])
    dense_results = clf.evals_result()

    # scipy sparse: same data, every block converted to CSR
    X, y = da.from_array(X_).map_blocks(sparse.csr_matrix), da.from_array(y_)
    clf = xgb.dask.DaskXGBClassifier(tree_method="hist", n_estimators=10)
    clf.client = client
    clf.fit(X, y, eval_set=[(X, y)])
    sparse_results = clf.evals_result()

    np.testing.assert_allclose(
        dense_results["validation_0"]["mlogloss"],
        sparse_results["validation_0"]["mlogloss"],
    )
def test_dask_predict_shape_infer(client: "Client") -> None:
X, y = make_classification(n_samples=1000, n_informative=5, n_classes=3)
X_ = dd.from_array(X, chunksize=100)
@@ -270,7 +294,8 @@ def run_boost_from_prediction(
def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
from sklearn.datasets import load_breast_cancer, load_digits
X_, y_ = load_breast_cancer(return_X_y=True)
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
X, y = dd.from_array(X_, chunksize=200), dd.from_array(y_, chunksize=200)
run_boost_from_prediction(X, y, tree_method, client)
X_, y_ = load_digits(return_X_y=True)