- Move test files. - Run spark and dask separately to prevent conflicts. - Gather common code into the testing module.
43 lines
1.3 KiB
Python
43 lines
1.3 KiB
Python
from typing import Generator, Sequence
|
|
|
|
import pytest
|
|
|
|
from xgboost import testing as tm
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def setup_rmm_pool(request, pytestconfig: pytest.Config) -> None:
|
|
tm.setup_rmm_pool(request, pytestconfig)
|
|
|
|
|
|
@pytest.fixture(scope="class")
|
|
def local_cuda_client(request, pytestconfig: pytest.Config) -> Generator:
|
|
kwargs = {}
|
|
if hasattr(request, "param"):
|
|
kwargs.update(request.param)
|
|
if pytestconfig.getoption("--use-rmm-pool"):
|
|
if tm.no_rmm()["condition"]:
|
|
raise ImportError("The --use-rmm-pool option requires the RMM package")
|
|
import rmm
|
|
|
|
kwargs["rmm_pool_size"] = "2GB"
|
|
if tm.no_dask_cuda()["condition"]:
|
|
raise ImportError("The local_cuda_cluster fixture requires dask_cuda package")
|
|
from dask.distributed import Client
|
|
from dask_cuda import LocalCUDACluster
|
|
|
|
yield Client(LocalCUDACluster(**kwargs))
|
|
|
|
|
|
def pytest_addoption(parser: pytest.Parser) -> None:
|
|
parser.addoption(
|
|
"--use-rmm-pool", action="store_true", default=False, help="Use RMM pool"
|
|
)
|
|
|
|
|
|
def pytest_collection_modifyitems(config: pytest.Config, items: Sequence) -> None:
|
|
# mark dask tests as `mgpu`.
|
|
mgpu_mark = pytest.mark.mgpu
|
|
for item in items:
|
|
item.add_marker(mgpu_mark)
|