Fix GPU quantile distributed test. (#8076)

This commit is contained in:
Jiaming Yuan 2022-07-16 11:40:53 +08:00 committed by GitHub
parent 8fccc3c4ad
commit 7a5586f3db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -499,14 +499,19 @@ class TestDistributedGPU:
for arg in rabit_args: for arg in rabit_args:
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'): if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
port_env = arg.decode('utf-8') port_env = arg.decode('utf-8')
port_env = arg.decode('utf-8')
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
uri_env = arg.decode("utf-8")
port = port_env.split('=') port = port_env.split('=')
env = os.environ.copy() env = os.environ.copy()
env[port[0]] = port[1] env[port[0]] = port[1]
uri = uri_env.split("=")
env[uri[0]] = uri[1]
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE) return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
with Client(local_cuda_cluster) as client: with Client(local_cuda_cluster) as client:
workers = _get_client_workers(client) workers = _get_client_workers(client)
rabit_args = client.sync(dxgb._get_rabit_args, workers, None, client) rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client)
futures = client.map(runit, futures = client.map(runit,
workers, workers,
pure=False, pure=False,