[dask] Deterministic rank assignment. (#8018) (#8165)

This commit is contained in:
Jiaming Yuan
2022-08-15 15:18:26 +08:00
committed by GitHub
parent 2e6444b342
commit b18c984035
3 changed files with 90 additions and 19 deletions

View File

@@ -4,6 +4,7 @@ import pytest
import testing as tm
import numpy as np
import sys
import re
# Skip every test in this module when the interpreter reports a Windows
# platform; `allow_module_level=True` is what permits calling `pytest.skip`
# outside of a test function.
if sys.platform.startswith("win"):
    pytest.skip(
        "Skipping dask tests on Windows",
        allow_module_level=True,
    )
@@ -59,3 +60,34 @@ def test_rabit_ops():
with LocalCluster(n_workers=n_workers) as cluster:
with Client(cluster) as client:
run_rabit_ops(client, n_workers)
def test_rank_assignment() -> None:
    """Check that the rabit rank of each worker matches its ``DMLC_TASK_ID``.

    A local cluster with 8 workers is started, the rabit arguments are
    obtained through ``xgb.dask._get_rabit_args`` and every worker then
    verifies, inside a ``RabitContext``, that the rank reported by
    ``xgb.rabit.get_rank()`` equals the digit embedded in the
    ``DMLC_TASK_ID`` entry of those arguments.
    """
    from distributed import Client, LocalCluster
    from test_with_dask import _get_client_workers

    def local_test(worker_id):
        """Run on one worker and compare its rank with the task ID digit.

        ``worker_id`` is only the item handed over by ``client.map``; the
        check itself relies on the ``args`` captured from the enclosing scope.
        """
        with xgb.dask.RabitContext(args):
            for val in args:
                sval = val.decode("utf-8")
                if sval.startswith("DMLC_TASK_ID"):
                    task_id = sval
                    break
            else:
                # Without this branch a missing entry would surface as an
                # opaque UnboundLocalError on `task_id` below.
                raise AssertionError(
                    f"No DMLC_TASK_ID entry found in rabit args: {args}"
                )

            matched = re.search(".*-([0-9]).*", task_id)
            assert matched is not None, (
                f"Failed to extract a worker index from task ID: {task_id}"
            )
            rank = xgb.rabit.get_rank()
            # The pattern captures a single digit only, so this comparison is
            # valid as long as there are fewer than 10 workers; in that case
            # the rank and the index in the task ID should be the same.
            expected = int(matched.group(1))
            assert rank == expected, (
                f"Rank {rank} does not match task ID {task_id}"
            )

    with LocalCluster(n_workers=8) as cluster:
        with Client(cluster) as client:
            workers = _get_client_workers(client)
            args = client.sync(
                xgb.dask._get_rabit_args,
                len(workers),
                None,
                client,
            )

            # One task per worker; gathering re-raises any assertion failure
            # that happened remotely.
            futures = client.map(local_test, range(len(workers)), workers=workers)
            client.gather(futures)