From 7a5586f3db8585dbec2b08f660b765ad5d2c1d2e Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 16 Jul 2022 11:40:53 +0800 Subject: [PATCH] Fix GPU quantile distributed test. (#8076) --- tests/python-gpu/test_gpu_with_dask.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index 8a947312e..dcb228adb 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -291,7 +291,7 @@ class TestDistributedGPU: valid_X = X valid_y = y cls = dxgb.DaskXGBClassifier(objective='binary:logistic', - tree_method='gpu_hist', + tree_method='gpu_hist', eval_metric='error', n_estimators=100) cls.client = client @@ -499,14 +499,19 @@ class TestDistributedGPU: for arg in rabit_args: if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'): port_env = arg.decode('utf-8') + port_env = arg.decode('utf-8') + if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"): + uri_env = arg.decode("utf-8") port = port_env.split('=') env = os.environ.copy() env[port[0]] = port[1] + uri = uri_env.split("=") + env[uri[0]] = uri[1] return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE) with Client(local_cuda_cluster) as client: workers = _get_client_workers(client) - rabit_args = client.sync(dxgb._get_rabit_args, workers, None, client) + rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client) futures = client.map(runit, workers, pure=False,