[dask] Deterministic rank assignment. (#8018) (#8165)

Jiaming Yuan 2022-08-15 15:18:26 +08:00 committed by GitHub
parent 2e6444b342
commit b18c984035
3 changed files with 90 additions and 19 deletions

python-package/xgboost/dask.py

@@ -177,9 +177,11 @@ def _try_start_tracker(
             use_logger=False,
         )
     else:
-        assert isinstance(addrs[0], str) or addrs[0] is None
+        addr = addrs[0]
+        assert isinstance(addr, str) or addr is None
+        host_ip = get_host_ip(addr)
         rabit_context = RabitTracker(
-            host_ip=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
+            host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
         )
         env.update(rabit_context.worker_envs())
         rabit_context.start(n_workers)
@@ -229,8 +231,16 @@ class RabitContext:
     def __init__(self, args: List[bytes]) -> None:
         self.args = args
         worker = distributed.get_worker()
+        with distributed.worker_client() as client:
+            info = client.scheduler_info()
+            w = info["workers"][worker.address]
+            wid = w["id"]
+        # We use the task ID for rank assignment, which makes the RABIT rank
+        # consistent with the dask worker ID (but not identical, since the task
+        # ID is a string and "10" sorts before "2"). This outsources the rank
+        # assignment to dask and avoids non-deterministic assignment.
         self.args.append(
-            ("DMLC_TASK_ID=[xgboost.dask]:" + str(worker.address)).encode()
+            (f"DMLC_TASK_ID=[xgboost.dask-{wid}]:" + str(worker.address)).encode()
         )

     def __enter__(self) -> None:
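The comment in this hunk is worth unpacking: the tracker compares DMLC_TASK_ID values as strings, so ranks follow lexicographic order of the embedded worker ids rather than numeric order. A minimal sketch of the effect (the task-ID strings below are illustrative, mirroring the f-string format above):

    # Task IDs embed the dask worker id, but sort lexicographically:
    ids = ["[xgboost.dask-2]:tcp://10.0.0.1:9000", "[xgboost.dask-10]:tcp://10.0.0.2:9000"]
    print(sorted(ids))  # "...dask-10..." sorts before "...dask-2...", since "1" < "2"

The assignment is nevertheless deterministic: the same set of workers always yields the same ordering, which is the point of the change.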
@ -870,6 +880,8 @@ async def _get_rabit_args(
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
sched_addr = None sched_addr = None
# make sure all workers are online so that we can obtain reliable scheduler_info
client.wait_for_workers(n_workers)
env = await client.run_on_scheduler( env = await client.run_on_scheduler(
_start_tracker, n_workers, sched_addr, user_addr _start_tracker, n_workers, sched_addr, user_addr
) )

python-package/xgboost/tracker.py

@@ -32,15 +32,15 @@ class ExSocket:
             chunk = self.sock.recv(min(nbytes - nread, 1024))
             nread += len(chunk)
             res.append(chunk)
-        return b''.join(res)
+        return b"".join(res)

     def recvint(self) -> int:
         """Receive an integer of 32 bytes"""
-        return struct.unpack('@i', self.recvall(4))[0]
+        return struct.unpack("@i", self.recvall(4))[0]

     def sendint(self, value: int) -> None:
         """Send an integer of 32 bytes"""
-        self.sock.sendall(struct.pack('@i', value))
+        self.sock.sendall(struct.pack("@i", value))

     def sendstr(self, value: str) -> None:
         """Send a Python string"""
@@ -69,6 +69,7 @@ def get_family(addr: str) -> int:

 class WorkerEntry:
     """Hanlder to each worker."""
+
     def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
         worker = ExSocket(sock)
         self.sock = worker
@@ -78,7 +79,7 @@ class WorkerEntry:
         worker.sendint(MAGIC_NUM)
         self.rank = worker.recvint()
         self.world_size = worker.recvint()
-        self.jobid = worker.recvstr()
+        self.task_id = worker.recvstr()
         self.cmd = worker.recvstr()
         self.wait_accept = 0
         self.port: Optional[int] = None
@@ -96,8 +97,8 @@ class WorkerEntry:
         """Get the rank of current entry."""
         if self.rank >= 0:
             return self.rank
-        if self.jobid != 'NULL' and self.jobid in job_map:
-            return job_map[self.jobid]
+        if self.task_id != "NULL" and self.task_id in job_map:
+            return job_map[self.task_id]
         return -1

     def assign_rank(
@@ -180,7 +181,12 @@ class RabitTracker:
     """

     def __init__(
-        self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False
+        self,
+        host_ip: str,
+        n_workers: int,
+        port: int = 0,
+        use_logger: bool = False,
+        sortby: str = "host",
     ) -> None:
         """A Python implementation of RABIT tracker.
@@ -190,6 +196,13 @@ class RabitTracker:
             Use logging.info for tracker print command. When set to False, Python print
             function is used instead.
+        sortby:
+            How to sort the workers for rank assignment. The default is host, but users
+            can set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
+            deterministic rank assignment. Available options are:
+              - host
+              - task
+
         """
         sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
         sock.bind((host_ip, port))
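For context, a minimal sketch of driving the tracker with the new option, mirroring the _try_start_tracker call in the first file. This assumes RabitTracker is importable from xgboost.tracker; the host and worker count are placeholders:

    from xgboost.tracker import RabitTracker

    # Sort connecting workers by their DMLC_TASK_ID instead of by host name.
    tracker = RabitTracker(
        host_ip="127.0.0.1", n_workers=2, use_logger=False, sortby="task"
    )
    envs = tracker.worker_envs()  # {"DMLC_TRACKER_URI": ..., "DMLC_TRACKER_PORT": ...}
    tracker.start(2)  # wait, on a background thread, for 2 workers to connect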
@ -200,6 +213,7 @@ class RabitTracker:
self.thread: Optional[Thread] = None self.thread: Optional[Thread] = None
self.n_workers = n_workers self.n_workers = n_workers
self._use_logger = use_logger self._use_logger = use_logger
self._sortby = sortby
logging.info("start listen on %s:%d", host_ip, self.port) logging.info("start listen on %s:%d", host_ip, self.port)
def __del__(self) -> None: def __del__(self) -> None:
@@ -223,7 +237,7 @@ class RabitTracker:
         get environment variables for workers
         can be passed in as args or envs
         """
-        return {'DMLC_TRACKER_URI': self.host_ip, 'DMLC_TRACKER_PORT': self.port}
+        return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}

     def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
         tree_map: _TreeMap = {}
@@ -296,8 +310,16 @@ class RabitTracker:
             parent_map_[rmap[k]] = -1
         return tree_map_, parent_map_, ring_map_

+    def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]:
+        if self._sortby == "host":
+            pending.sort(key=lambda s: s.host)
+        elif self._sortby == "task":
+            pending.sort(key=lambda s: s.task_id)
+        return pending
+
     def accept_workers(self, n_workers: int) -> None:
         """Wait for all workers to connect to the tracker."""
         # set of nodes that finishes the job
         shutdown: Dict[int, WorkerEntry] = {}
         # set of nodes that is waiting for connections
@@ -341,27 +363,32 @@ class RabitTracker:
                     assert todo_nodes
                     pending.append(s)
                     if len(pending) == len(todo_nodes):
-                        pending.sort(key=lambda x: x.host)
+                        pending = self._sort_pending(pending)
                         for s in pending:
                             rank = todo_nodes.pop(0)
-                            if s.jobid != 'NULL':
-                                job_map[s.jobid] = rank
+                            if s.task_id != "NULL":
+                                job_map[s.task_id] = rank
                             s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                             if s.wait_accept > 0:
                                 wait_conn[rank] = s
-                            logging.debug('Received %s signal from %s; assign rank %d',
-                                          s.cmd, s.host, s.rank)
+                            logging.debug(
+                                "Received %s signal from %s; assign rank %d",
+                                s.cmd,
+                                s.host,
+                                s.rank,
+                            )
                         if not todo_nodes:
-                            logging.info('@tracker All of %d nodes getting started', n_workers)
+                            logging.info("@tracker All of %d nodes getting started", n_workers)
                 else:
                     s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
-                    logging.debug('Received %s signal from %d', s.cmd, s.rank)
+                    logging.debug("Received %s signal from %d", s.cmd, s.rank)
                     if s.wait_accept > 0:
                         wait_conn[rank] = s
-        logging.info('@tracker All nodes finishes job')
+        logging.info("@tracker All nodes finishes job")

     def start(self, n_workers: int) -> None:
         """Strat the tracker, it will wait for `n_workers` to connect."""
+
         def run() -> None:
             self.accept_workers(n_workers)

tests/python/test_tracker.py

@@ -4,6 +4,7 @@ import pytest
 import testing as tm
 import numpy as np
 import sys
+import re

 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -59,3 +60,34 @@ def test_rabit_ops():
     with LocalCluster(n_workers=n_workers) as cluster:
         with Client(cluster) as client:
             run_rabit_ops(client, n_workers)
+
+
+def test_rank_assignment() -> None:
+    from distributed import Client, LocalCluster
+    from test_with_dask import _get_client_workers
+
+    def local_test(worker_id):
+        with xgb.dask.RabitContext(args):
+            for val in args:
+                sval = val.decode("utf-8")
+                if sval.startswith("DMLC_TASK_ID"):
+                    task_id = sval
+                    break
+            matched = re.search(".*-([0-9]).*", task_id)
+            rank = xgb.rabit.get_rank()
+            # As long as the number of workers is fewer than 10, rank and worker
+            # id should be the same.
+            assert rank == int(matched.group(1))
+
+    with LocalCluster(n_workers=8) as cluster:
+        with Client(cluster) as client:
+            workers = _get_client_workers(client)
+            args = client.sync(
+                xgb.dask._get_rabit_args,
+                len(workers),
+                None,
+                client,
+            )
+            futures = client.map(local_test, range(len(workers)), workers=workers)
+            client.gather(futures)
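A side note on the regex in local_test: `.*-([0-9]).*` captures exactly one digit, which is why the comment restricts the check to fewer than 10 workers. A quick illustration, with made-up addresses:

    import re

    pat = ".*-([0-9]).*"
    assert re.search(pat, "[xgboost.dask-7]:tcp://127.0.0.1:40000").group(1) == "7"
    # With a two-digit worker id, the greedy match still captures one digit:
    assert re.search(pat, "[xgboost.dask-10]:tcp://127.0.0.1:40001").group(1) == "1"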