[dask] Deterministic rank assignment. (#8018)
This commit is contained in:
@@ -4,6 +4,7 @@ import pytest
|
||||
import testing as tm
|
||||
import numpy as np
|
||||
import sys
|
||||
import re
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||
@@ -58,3 +59,34 @@ def test_rabit_ops():
|
||||
with LocalCluster(n_workers=n_workers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
run_rabit_ops(client, n_workers)
|
||||
|
||||
|
||||
def test_rank_assignment() -> None:
|
||||
from distributed import Client, LocalCluster
|
||||
from test_with_dask import _get_client_workers
|
||||
|
||||
def local_test(worker_id):
|
||||
with xgb.dask.RabitContext(args):
|
||||
for val in args:
|
||||
sval = val.decode("utf-8")
|
||||
if sval.startswith("DMLC_TASK_ID"):
|
||||
task_id = sval
|
||||
break
|
||||
matched = re.search(".*-([0-9]).*", task_id)
|
||||
rank = xgb.rabit.get_rank()
|
||||
# As long as the number of workers is lesser than 10, rank and worker id
|
||||
# should be the same
|
||||
assert rank == int(matched.group(1))
|
||||
|
||||
with LocalCluster(n_workers=8) as cluster:
|
||||
with Client(cluster) as client:
|
||||
workers = _get_client_workers(client)
|
||||
args = client.sync(
|
||||
xgb.dask._get_rabit_args,
|
||||
len(workers),
|
||||
None,
|
||||
client,
|
||||
)
|
||||
|
||||
futures = client.map(local_test, range(len(workers)), workers=workers)
|
||||
client.gather(futures)
|
||||
|
||||
Reference in New Issue
Block a user