[CI] Fix Dask Pytest fixture (#6024)
This commit is contained in:
parent
d240463b38
commit
14d5ce712c
@ -32,14 +32,12 @@ def local_cuda_cluster(request, pytestconfig):
|
||||
raise ImportError('The --use-rmm-pool option requires the RMM package')
|
||||
import rmm
|
||||
from dask_cuda.utils import get_n_gpus
|
||||
rmm.reinitialize()
|
||||
kwargs['rmm_pool_size'] = '2GB'
|
||||
if tm.no_dask_cuda()['condition']:
|
||||
raise ImportError('The local_cuda_cluster fixture requires dask_cuda package')
|
||||
from dask_cuda import LocalCUDACluster
|
||||
cluster = LocalCUDACluster(**kwargs)
|
||||
yield cluster
|
||||
cluster.close()
|
||||
with LocalCUDACluster(**kwargs) as cluster:
|
||||
yield cluster
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption('--use-rmm-pool', action='store_true', default=False, help='Use RMM pool')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user