parent 2e6444b342
commit b18c984035
@@ -177,9 +177,11 @@ def _try_start_tracker(
                 use_logger=False,
             )
         else:
-            assert isinstance(addrs[0], str) or addrs[0] is None
+            addr = addrs[0]
+            assert isinstance(addr, str) or addr is None
+            host_ip = get_host_ip(addr)
             rabit_context = RabitTracker(
-                host_ip=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
+                host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
             )
         env.update(rabit_context.worker_envs())
         rabit_context.start(n_workers)
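Editor's note: the tracker-side API this hunk relies on can be sketched in isolation. The import path and the `get_host_ip(None)` default below are assumptions for illustration, not part of the diff; the constructor keywords, `worker_envs()`, and `start()` are taken from the hunks further down.

import xgboost  # noqa: F401
from xgboost.tracker import RabitTracker, get_host_ip  # import path assumed

host_ip = get_host_ip(None)  # resolve a local address, as the hunk above does for `addr`
tracker = RabitTracker(host_ip=host_ip, n_workers=2, use_logger=False, sortby="task")
envs = tracker.worker_envs()  # {'DMLC_TRACKER_URI': ..., 'DMLC_TRACKER_PORT': ...}
tracker.start(2)  # begin waiting for 2 workers to connect, as in the hunk above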
@@ -229,8 +231,16 @@ class RabitContext:
     def __init__(self, args: List[bytes]) -> None:
         self.args = args
         worker = distributed.get_worker()
+        with distributed.worker_client() as client:
+            info = client.scheduler_info()
+            w = info["workers"][worker.address]
+            wid = w["id"]
+        # We use task ID for rank assignment which makes the RABIT rank consistent (but
+        # not the same as task ID is string and "10" is sorted before "2") with dask
+        # worker ID. This outsources the rank assignment to dask and prevents
+        # non-deterministic issue.
         self.args.append(
-            ("DMLC_TASK_ID=[xgboost.dask]:" + str(worker.address)).encode()
+            (f"DMLC_TASK_ID=[xgboost.dask-{wid}]:" + str(worker.address)).encode()
         )

     def __enter__(self) -> None:
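Editor's note: the comment above is the heart of the change. Ranks follow the lexicographic order of `DMLC_TASK_ID`, which now embeds the Dask worker id. A small standalone illustration (addresses made up) of why the ordering is deterministic even though it is not numerically equal to the id once ids reach two digits:

task_ids = [f"[xgboost.dask-{i}]:tcp://127.0.0.1:4000{i}" for i in (2, 0, 10, 1)]
# Lexicographic sort: "-10" sorts before "-2", so the order is stable across runs
# even though rank != worker id once there are 10 or more workers.
for rank, task_id in enumerate(sorted(task_ids)):
    print(rank, task_id)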
@@ -870,6 +880,8 @@ async def _get_rabit_args(
     except Exception:  # pylint: disable=broad-except
         sched_addr = None

+    # make sure all workers are online so that we can obtain reliable scheduler_info
+    client.wait_for_workers(n_workers)
     env = await client.run_on_scheduler(
         _start_tracker, n_workers, sched_addr, user_addr
     )
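Editor's note: `wait_for_workers` guards the `scheduler_info()` lookup that `RabitContext` now performs. A minimal, self-contained sketch of the same pattern (cluster size is arbitrary):

from distributed import Client, LocalCluster

with LocalCluster(n_workers=2) as cluster:
    with Client(cluster) as client:
        # Block until both workers have registered; otherwise the worker table
        # read below may be missing entries on a cluster that is still starting.
        client.wait_for_workers(2)
        info = client.scheduler_info()
        assert len(info["workers"]) == 2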
@@ -32,15 +32,15 @@ class ExSocket:
             chunk = self.sock.recv(min(nbytes - nread, 1024))
             nread += len(chunk)
             res.append(chunk)
-        return b''.join(res)
+        return b"".join(res)

     def recvint(self) -> int:
         """Receive an integer of 32 bytes"""
-        return struct.unpack('@i', self.recvall(4))[0]
+        return struct.unpack("@i", self.recvall(4))[0]

     def sendint(self, value: int) -> None:
         """Send an integer of 32 bytes"""
-        self.sock.sendall(struct.pack('@i', value))
+        self.sock.sendall(struct.pack("@i", value))

     def sendstr(self, value: str) -> None:
         """Send a Python string"""
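Editor's note: the docstrings' "integer of 32 bytes" means a 32-bit (4-byte) integer. `"@i"` is struct's native-order signed int, which is 4 bytes on common platforms, so `recvint` reads exactly `recvall(4)`. A tiny sanity check of the wire format:

import struct

payload = struct.pack("@i", 42)   # native-order signed int
assert len(payload) == 4          # 4 bytes on common platforms
assert struct.unpack("@i", payload)[0] == 42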
@@ -69,6 +69,7 @@ def get_family(addr: str) -> int:

 class WorkerEntry:
+    """Hanlder to each worker."""

     def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
         worker = ExSocket(sock)
         self.sock = worker
@@ -78,7 +79,7 @@ class WorkerEntry:
         worker.sendint(MAGIC_NUM)
         self.rank = worker.recvint()
         self.world_size = worker.recvint()
-        self.jobid = worker.recvstr()
+        self.task_id = worker.recvstr()
         self.cmd = worker.recvstr()
         self.wait_accept = 0
         self.port: Optional[int] = None
@@ -96,8 +97,8 @@ class WorkerEntry:
         """Get the rank of current entry."""
         if self.rank >= 0:
             return self.rank
-        if self.jobid != 'NULL' and self.jobid in job_map:
-            return job_map[self.jobid]
+        if self.task_id != "NULL" and self.task_id in job_map:
+            return job_map[self.task_id]
         return -1

     def assign_rank(
@@ -180,7 +181,12 @@ class RabitTracker:
     """

     def __init__(
-        self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False
+        self,
+        host_ip: str,
+        n_workers: int,
+        port: int = 0,
+        use_logger: bool = False,
+        sortby: str = "host",
     ) -> None:
         """A Python implementation of RABIT tracker.

@@ -190,6 +196,13 @@ class RabitTracker:
             Use logging.info for tracker print command. When set to False, Python print
             function is used instead.

+        sortby:
+            How to sort the workers for rank assignment. The default is host, but users
+            can set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
+            deterministic rank assignment. Available options are:
+              - host
+              - task
+
         """
         sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
         sock.bind((host_ip, port))
@@ -200,6 +213,7 @@ class RabitTracker:
         self.thread: Optional[Thread] = None
         self.n_workers = n_workers
         self._use_logger = use_logger
+        self._sortby = sortby
         logging.info("start listen on %s:%d", host_ip, self.port)

     def __del__(self) -> None:
@@ -223,7 +237,7 @@ class RabitTracker:
         get environment variables for workers
         can be passed in as args or envs
         """
-        return {'DMLC_TRACKER_URI': self.host_ip, 'DMLC_TRACKER_PORT': self.port}
+        return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}

     def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
         tree_map: _TreeMap = {}
@@ -296,8 +310,16 @@ class RabitTracker:
             parent_map_[rmap[k]] = -1
         return tree_map_, parent_map_, ring_map_

+    def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]:
+        if self._sortby == "host":
+            pending.sort(key=lambda s: s.host)
+        elif self._sortby == "task":
+            pending.sort(key=lambda s: s.task_id)
+        return pending
+
     def accept_workers(self, n_workers: int) -> None:
+        """Wait for all workers to connect to the tracker."""

         # set of nodes that finishes the job
         shutdown: Dict[int, WorkerEntry] = {}
         # set of nodes that is waiting for connections
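Editor's note: the new `_sort_pending` hook is the single point where the `sortby` policy decides rank order. A toy illustration with a hypothetical stand-in for `WorkerEntry` carrying only the two fields the hook reads:

from collections import namedtuple

Entry = namedtuple("Entry", ["host", "task_id"])  # hypothetical stand-in

pending = [
    Entry(host="10.0.0.2", task_id="[xgboost.dask-1]:tcp://10.0.0.2:40001"),
    Entry(host="10.0.0.1", task_id="[xgboost.dask-0]:tcp://10.0.0.1:40001"),
]
by_host = sorted(pending, key=lambda s: s.host)     # sortby="host" ordering
by_task = sorted(pending, key=lambda s: s.task_id)  # sortby="task" ordering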
@@ -341,27 +363,32 @@
                     assert todo_nodes
                     pending.append(s)
                     if len(pending) == len(todo_nodes):
-                        pending.sort(key=lambda x: x.host)
+                        pending = self._sort_pending(pending)
                         for s in pending:
                             rank = todo_nodes.pop(0)
-                            if s.jobid != 'NULL':
-                                job_map[s.jobid] = rank
+                            if s.task_id != "NULL":
+                                job_map[s.task_id] = rank
                             s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                             if s.wait_accept > 0:
                                 wait_conn[rank] = s
-                            logging.debug('Received %s signal from %s; assign rank %d',
-                                          s.cmd, s.host, s.rank)
+                            logging.debug(
+                                "Received %s signal from %s; assign rank %d",
+                                s.cmd,
+                                s.host,
+                                s.rank,
+                            )
                         if not todo_nodes:
-                            logging.info('@tracker All of %d nodes getting started', n_workers)
+                            logging.info("@tracker All of %d nodes getting started", n_workers)
                 else:
                     s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
-                    logging.debug('Received %s signal from %d', s.cmd, s.rank)
+                    logging.debug("Received %s signal from %d", s.cmd, s.rank)
                     if s.wait_accept > 0:
                         wait_conn[rank] = s
-        logging.info('@tracker All nodes finishes job')
+        logging.info("@tracker All nodes finishes job")

     def start(self, n_workers: int) -> None:
         """Strat the tracker, it will wait for `n_workers` to connect."""

         def run() -> None:
             self.accept_workers(n_workers)
@@ -4,6 +4,7 @@ import pytest
 import testing as tm
 import numpy as np
 import sys
+import re

 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -59,3 +60,34 @@ def test_rabit_ops():
     with LocalCluster(n_workers=n_workers) as cluster:
         with Client(cluster) as client:
             run_rabit_ops(client, n_workers)
+
+
+def test_rank_assignment() -> None:
+    from distributed import Client, LocalCluster
+    from test_with_dask import _get_client_workers
+
+    def local_test(worker_id):
+        with xgb.dask.RabitContext(args):
+            for val in args:
+                sval = val.decode("utf-8")
+                if sval.startswith("DMLC_TASK_ID"):
+                    task_id = sval
+                    break
+            matched = re.search(".*-([0-9]).*", task_id)
+            rank = xgb.rabit.get_rank()
+            # As long as the number of workers is lesser than 10, rank and worker id
+            # should be the same
+            assert rank == int(matched.group(1))
+
+    with LocalCluster(n_workers=8) as cluster:
+        with Client(cluster) as client:
+            workers = _get_client_workers(client)
+            args = client.sync(
+                xgb.dask._get_rabit_args,
+                len(workers),
+                None,
+                client,
+            )
+
+            futures = client.map(local_test, range(len(workers)), workers=workers)
+            client.gather(futures)
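Editor's note: the test's regex `.*-([0-9]).*` captures a single digit from the `DMLC_TASK_ID`, which is why the comment limits the check to fewer than 10 workers. A quick standalone check of the extraction (the task-id shape is taken from the RabitContext hunk above; the concrete address is made up):

import re

task_id = "DMLC_TASK_ID=[xgboost.dask-7]:tcp://127.0.0.1:40007"
matched = re.search(".*-([0-9]).*", task_id)
assert matched is not None and int(matched.group(1)) == 7

With the eight workers used in the test, every id is a single digit, so asserting rank == worker id is safe.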