Fix GPU quantile distributed test. (#8076)
This commit is contained in:
parent
8fccc3c4ad
commit
7a5586f3db
@ -291,7 +291,7 @@ class TestDistributedGPU:
|
||||
valid_X = X
|
||||
valid_y = y
|
||||
cls = dxgb.DaskXGBClassifier(objective='binary:logistic',
|
||||
tree_method='gpu_hist',
|
||||
tree_method='gpu_hist',
|
||||
eval_metric='error',
|
||||
n_estimators=100)
|
||||
cls.client = client
|
||||
@ -499,14 +499,19 @@ class TestDistributedGPU:
|
||||
for arg in rabit_args:
|
||||
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
|
||||
port_env = arg.decode('utf-8')
|
||||
port_env = arg.decode('utf-8')
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
|
||||
uri_env = arg.decode("utf-8")
|
||||
port = port_env.split('=')
|
||||
env = os.environ.copy()
|
||||
env[port[0]] = port[1]
|
||||
uri = uri_env.split("=")
|
||||
env[uri[0]] = uri[1]
|
||||
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
|
||||
|
||||
with Client(local_cuda_cluster) as client:
|
||||
workers = _get_client_workers(client)
|
||||
rabit_args = client.sync(dxgb._get_rabit_args, workers, None, client)
|
||||
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client)
|
||||
futures = client.map(runit,
|
||||
workers,
|
||||
pure=False,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user