Fix GPU quantile distributed test. (#8076)
This commit is contained in:
parent
8fccc3c4ad
commit
7a5586f3db
@ -499,14 +499,19 @@ class TestDistributedGPU:
|
|||||||
for arg in rabit_args:
|
for arg in rabit_args:
|
||||||
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
|
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
|
||||||
port_env = arg.decode('utf-8')
|
port_env = arg.decode('utf-8')
|
||||||
|
port_env = arg.decode('utf-8')
|
||||||
|
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
|
||||||
|
uri_env = arg.decode("utf-8")
|
||||||
port = port_env.split('=')
|
port = port_env.split('=')
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
env[port[0]] = port[1]
|
env[port[0]] = port[1]
|
||||||
|
uri = uri_env.split("=")
|
||||||
|
env[uri[0]] = uri[1]
|
||||||
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
|
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
|
||||||
|
|
||||||
with Client(local_cuda_cluster) as client:
|
with Client(local_cuda_cluster) as client:
|
||||||
workers = _get_client_workers(client)
|
workers = _get_client_workers(client)
|
||||||
rabit_args = client.sync(dxgb._get_rabit_args, workers, None, client)
|
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client)
|
||||||
futures = client.map(runit,
|
futures = client.map(runit,
|
||||||
workers,
|
workers,
|
||||||
pure=False,
|
pure=False,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user