[dask] Deterministic rank assignment. (#8018) (#8165)

This commit is contained in:
Jiaming Yuan 2022-08-15 15:18:26 +08:00 committed by GitHub
parent 2e6444b342
commit b18c984035
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 90 additions and 19 deletions

View File

@ -177,9 +177,11 @@ def _try_start_tracker(
use_logger=False,
)
else:
assert isinstance(addrs[0], str) or addrs[0] is None
addr = addrs[0]
assert isinstance(addr, str) or addr is None
host_ip = get_host_ip(addr)
rabit_context = RabitTracker(
host_ip=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
)
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)
@ -229,8 +231,16 @@ class RabitContext:
def __init__(self, args: List[bytes]) -> None:
self.args = args
worker = distributed.get_worker()
with distributed.worker_client() as client:
info = client.scheduler_info()
w = info["workers"][worker.address]
wid = w["id"]
# We use task ID for rank assignment which makes the RABIT rank consistent (but
# not the same as task ID is string and "10" is sorted before "2") with dask
# worker ID. This outsources the rank assignment to dask and prevents
# non-deterministic issue.
self.args.append(
("DMLC_TASK_ID=[xgboost.dask]:" + str(worker.address)).encode()
(f"DMLC_TASK_ID=[xgboost.dask-{wid}]:" + str(worker.address)).encode()
)
def __enter__(self) -> None:
@ -870,6 +880,8 @@ async def _get_rabit_args(
except Exception: # pylint: disable=broad-except
sched_addr = None
# make sure all workers are online so that we can obtain reliable scheduler_info
client.wait_for_workers(n_workers)
env = await client.run_on_scheduler(
_start_tracker, n_workers, sched_addr, user_addr
)

View File

@ -32,15 +32,15 @@ class ExSocket:
chunk = self.sock.recv(min(nbytes - nread, 1024))
nread += len(chunk)
res.append(chunk)
return b''.join(res)
return b"".join(res)
def recvint(self) -> int:
"""Receive an integer of 32 bytes"""
return struct.unpack('@i', self.recvall(4))[0]
return struct.unpack("@i", self.recvall(4))[0]
def sendint(self, value: int) -> None:
"""Send an integer of 32 bytes"""
self.sock.sendall(struct.pack('@i', value))
self.sock.sendall(struct.pack("@i", value))
def sendstr(self, value: str) -> None:
"""Send a Python string"""
@ -69,6 +69,7 @@ def get_family(addr: str) -> int:
class WorkerEntry:
"""Hanlder to each worker."""
def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
worker = ExSocket(sock)
self.sock = worker
@ -78,7 +79,7 @@ class WorkerEntry:
worker.sendint(MAGIC_NUM)
self.rank = worker.recvint()
self.world_size = worker.recvint()
self.jobid = worker.recvstr()
self.task_id = worker.recvstr()
self.cmd = worker.recvstr()
self.wait_accept = 0
self.port: Optional[int] = None
@ -96,8 +97,8 @@ class WorkerEntry:
"""Get the rank of current entry."""
if self.rank >= 0:
return self.rank
if self.jobid != 'NULL' and self.jobid in job_map:
return job_map[self.jobid]
if self.task_id != "NULL" and self.task_id in job_map:
return job_map[self.task_id]
return -1
def assign_rank(
@ -180,7 +181,12 @@ class RabitTracker:
"""
def __init__(
self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False
self,
host_ip: str,
n_workers: int,
port: int = 0,
use_logger: bool = False,
sortby: str = "host",
) -> None:
"""A Python implementation of RABIT tracker.
@ -190,6 +196,13 @@ class RabitTracker:
Use logging.info for tracker print command. When set to False, Python print
function is used instead.
sortby:
How to sort the workers for rank assignment. The default is host, but users
can set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
deterministic rank assignment. Available options are:
- host
- task
"""
sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
sock.bind((host_ip, port))
@ -200,6 +213,7 @@ class RabitTracker:
self.thread: Optional[Thread] = None
self.n_workers = n_workers
self._use_logger = use_logger
self._sortby = sortby
logging.info("start listen on %s:%d", host_ip, self.port)
def __del__(self) -> None:
@ -223,7 +237,7 @@ class RabitTracker:
get environment variables for workers
can be passed in as args or envs
"""
return {'DMLC_TRACKER_URI': self.host_ip, 'DMLC_TRACKER_PORT': self.port}
return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}
def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
tree_map: _TreeMap = {}
@ -296,8 +310,16 @@ class RabitTracker:
parent_map_[rmap[k]] = -1
return tree_map_, parent_map_, ring_map_
def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]:
if self._sortby == "host":
pending.sort(key=lambda s: s.host)
elif self._sortby == "task":
pending.sort(key=lambda s: s.task_id)
return pending
def accept_workers(self, n_workers: int) -> None:
"""Wait for all workers to connect to the tracker."""
# set of nodes that finishes the job
shutdown: Dict[int, WorkerEntry] = {}
# set of nodes that is waiting for connections
@ -341,27 +363,32 @@ class RabitTracker:
assert todo_nodes
pending.append(s)
if len(pending) == len(todo_nodes):
pending.sort(key=lambda x: x.host)
pending = self._sort_pending(pending)
for s in pending:
rank = todo_nodes.pop(0)
if s.jobid != 'NULL':
job_map[s.jobid] = rank
if s.task_id != "NULL":
job_map[s.task_id] = rank
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
if s.wait_accept > 0:
wait_conn[rank] = s
logging.debug('Received %s signal from %s; assign rank %d',
s.cmd, s.host, s.rank)
logging.debug(
"Received %s signal from %s; assign rank %d",
s.cmd,
s.host,
s.rank,
)
if not todo_nodes:
logging.info('@tracker All of %d nodes getting started', n_workers)
logging.info("@tracker All of %d nodes getting started", n_workers)
else:
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
logging.debug('Received %s signal from %d', s.cmd, s.rank)
logging.debug("Received %s signal from %d", s.cmd, s.rank)
if s.wait_accept > 0:
wait_conn[rank] = s
logging.info('@tracker All nodes finishes job')
logging.info("@tracker All nodes finishes job")
def start(self, n_workers: int) -> None:
"""Strat the tracker, it will wait for `n_workers` to connect."""
def run() -> None:
self.accept_workers(n_workers)

View File

@ -4,6 +4,7 @@ import pytest
import testing as tm
import numpy as np
import sys
import re
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@ -59,3 +60,34 @@ def test_rabit_ops():
with LocalCluster(n_workers=n_workers) as cluster:
with Client(cluster) as client:
run_rabit_ops(client, n_workers)
def test_rank_assignment() -> None:
from distributed import Client, LocalCluster
from test_with_dask import _get_client_workers
def local_test(worker_id):
with xgb.dask.RabitContext(args):
for val in args:
sval = val.decode("utf-8")
if sval.startswith("DMLC_TASK_ID"):
task_id = sval
break
matched = re.search(".*-([0-9]).*", task_id)
rank = xgb.rabit.get_rank()
# As long as the number of workers is lesser than 10, rank and worker id
# should be the same
assert rank == int(matched.group(1))
with LocalCluster(n_workers=8) as cluster:
with Client(cluster) as client:
workers = _get_client_workers(client)
args = client.sync(
xgb.dask._get_rabit_args,
len(workers),
None,
client,
)
futures = client.map(local_test, range(len(workers)), workers=workers)
client.gather(futures)