diff --git a/tests/python-gpu/conftest.py b/tests/python-gpu/conftest.py index 1865ce529..e493108a3 100644 --- a/tests/python-gpu/conftest.py +++ b/tests/python-gpu/conftest.py @@ -32,14 +32,12 @@ def local_cuda_cluster(request, pytestconfig): raise ImportError('The --use-rmm-pool option requires the RMM package') import rmm from dask_cuda.utils import get_n_gpus - rmm.reinitialize() kwargs['rmm_pool_size'] = '2GB' if tm.no_dask_cuda()['condition']: raise ImportError('The local_cuda_cluster fixture requires dask_cuda package') from dask_cuda import LocalCUDACluster - cluster = LocalCUDACluster(**kwargs) - yield cluster - cluster.close() + with LocalCUDACluster(**kwargs) as cluster: + yield cluster def pytest_addoption(parser): parser.addoption('--use-rmm-pool', action='store_true', default=False, help='Use RMM pool')