Optionally skip cupy on windows. (#10611)

This commit is contained in:
Jiaming Yuan 2024-07-20 22:12:12 +08:00 committed by GitHub
parent 344ddeb9ca
commit 0846ad860c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 30 deletions

View File

@ -45,6 +45,7 @@ from xgboost.testing.data import (
get_cancer, get_cancer,
get_digits, get_digits,
get_sparse, get_sparse,
make_batches,
memory, memory,
) )
@ -161,7 +162,16 @@ def no_cudf() -> PytestSkip:
def no_cupy() -> PytestSkip: def no_cupy() -> PytestSkip:
return no_mod("cupy") skip_cupy = no_mod("cupy")
if not skip_cupy["condition"] and system() == "Windows":
import cupy as cp
# Cupy might run into issue on Windows due to missing compiler
try:
cp.array([1, 2, 3]).sum()
except Exception: # pylint: disable=broad-except
skip_cupy["condition"] = True
return skip_cupy
def no_dask_cudf() -> PytestSkip: def no_dask_cudf() -> PytestSkip:
@ -248,35 +258,6 @@ class IteratorForTest(xgb.core.DataIter):
return X, y, w return X, y, w
def make_batches( # pylint: disable=too-many-arguments,too-many-locals
n_samples_per_batch: int,
n_features: int,
n_batches: int,
use_cupy: bool = False,
*,
vary_size: bool = False,
random_state: int = 1994,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
X = []
y = []
w = []
if use_cupy:
import cupy
rng = cupy.random.RandomState(random_state)
else:
rng = np.random.RandomState(random_state)
for i in range(n_batches):
n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
_X = rng.randn(n_samples, n_features)
_y = rng.randn(n_samples)
_w = rng.uniform(low=0, high=1, size=n_samples)
X.append(_X)
y.append(_y)
w.append(_w)
return X, y, w
def make_regression( def make_regression(
n_samples: int, n_features: int, use_cupy: bool n_samples: int, n_features: int, use_cupy: bool
) -> Tuple[ArrayLike, ArrayLike, ArrayLike]: ) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:

View File

@ -9,6 +9,7 @@ from typing import (
Callable, Callable,
Dict, Dict,
Generator, Generator,
List,
NamedTuple, NamedTuple,
Optional, Optional,
Tuple, Tuple,
@ -506,6 +507,36 @@ def get_mq2008(
) )
def make_batches( # pylint: disable=too-many-arguments,too-many-locals
n_samples_per_batch: int,
n_features: int,
n_batches: int,
use_cupy: bool = False,
*,
vary_size: bool = False,
random_state: int = 1994,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""Make batches of dense data."""
X = []
y = []
w = []
if use_cupy:
import cupy # pylint: disable=import-error
rng = cupy.random.RandomState(random_state)
else:
rng = np.random.RandomState(random_state)
for i in range(n_batches):
n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
_X = rng.randn(n_samples, n_features)
_y = rng.randn(n_samples)
_w = rng.uniform(low=0, high=1, size=n_samples)
X.append(_X)
y.append(_y)
w.append(_w)
return X, y, w
RelData = Tuple[sparse.csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32]] RelData = Tuple[sparse.csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32]]