Extract dask and spark test into distributed test. (#8395)
- Move test files. - Run spark and dask separately to prevent conflicts. - Gather common code into the testing module.
This commit is contained in:
@@ -1,43 +1,21 @@
|
||||
import pytest
|
||||
|
||||
from xgboost import testing as tm # noqa
|
||||
from xgboost import testing as tm
|
||||
|
||||
|
||||
def has_rmm():
|
||||
try:
|
||||
import rmm
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
return tm.no_rmm()["condition"]
|
||||
|
||||
@pytest.fixture(scope='session', autouse=True)
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_rmm_pool(request, pytestconfig):
|
||||
if pytestconfig.getoption('--use-rmm-pool'):
|
||||
if not has_rmm():
|
||||
raise ImportError('The --use-rmm-pool option requires the RMM package')
|
||||
import rmm
|
||||
from dask_cuda.utils import get_n_gpus
|
||||
rmm.reinitialize(pool_allocator=True, initial_pool_size=1024*1024*1024,
|
||||
devices=list(range(get_n_gpus())))
|
||||
tm.setup_rmm_pool(request, pytestconfig)
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def local_cuda_client(request, pytestconfig):
|
||||
kwargs = {}
|
||||
if hasattr(request, 'param'):
|
||||
kwargs.update(request.param)
|
||||
if pytestconfig.getoption('--use-rmm-pool'):
|
||||
if not has_rmm():
|
||||
raise ImportError('The --use-rmm-pool option requires the RMM package')
|
||||
import rmm
|
||||
kwargs['rmm_pool_size'] = '2GB'
|
||||
if tm.no_dask_cuda()['condition']:
|
||||
raise ImportError('The local_cuda_cluster fixture requires dask_cuda package')
|
||||
from dask.distributed import Client
|
||||
from dask_cuda import LocalCUDACluster
|
||||
yield Client(LocalCUDACluster(**kwargs))
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption('--use-rmm-pool', action='store_true', default=False, help='Use RMM pool')
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
parser.addoption(
|
||||
"--use-rmm-pool", action="store_true", default=False, help="Use RMM pool"
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
@@ -53,13 +31,3 @@ def pytest_collection_modifyitems(config, items):
|
||||
for item in items:
|
||||
if any(item.nodeid.startswith(x) for x in blocklist):
|
||||
item.add_marker(skip_mark)
|
||||
|
||||
# mark dask tests as `mgpu`.
|
||||
mgpu_mark = pytest.mark.mgpu
|
||||
for item in items:
|
||||
if item.nodeid.startswith(
|
||||
"python-gpu/test_gpu_with_dask/test_gpu_with_dask.py"
|
||||
) or item.nodeid.startswith(
|
||||
"python-gpu/test_gpu_spark/test_gpu_spark.py"
|
||||
):
|
||||
item.add_marker(mgpu_mark)
|
||||
|
||||
Reference in New Issue
Block a user