parent 2e6444b342
commit b18c984035
@@ -177,9 +177,11 @@ def _try_start_tracker(
                 use_logger=False,
             )
         else:
-            assert isinstance(addrs[0], str) or addrs[0] is None
+            addr = addrs[0]
+            assert isinstance(addr, str) or addr is None
+            host_ip = get_host_ip(addr)
             rabit_context = RabitTracker(
-                host_ip=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
+                host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
             )
         env.update(rabit_context.worker_envs())
         rabit_context.start(n_workers)
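
As a rough illustration only (not taken from this commit), the tracker configured above can also be driven standalone; this sketch assumes the xgboost.tracker module path and uses placeholder values for the host and worker count:

from xgboost.tracker import RabitTracker

# Ask the tracker to assign ranks by the workers' DMLC_TASK_ID rather than by
# host name, mirroring the sortby="task" argument added in the hunk above.
tracker = RabitTracker(
    host_ip="127.0.0.1", n_workers=2, use_logger=False, sortby="task"
)
tracker.start(2)
# DMLC_TRACKER_URI / DMLC_TRACKER_PORT are handed to each worker so it can
# connect back and report its task ID for rank assignment.
print(tracker.worker_envs())
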
@@ -229,8 +231,16 @@ class RabitContext:
     def __init__(self, args: List[bytes]) -> None:
         self.args = args
         worker = distributed.get_worker()
+        with distributed.worker_client() as client:
+            info = client.scheduler_info()
+            w = info["workers"][worker.address]
+            wid = w["id"]
+        # We use task ID for rank assignment which makes the RABIT rank consistent (but
+        # not the same as task ID is string and "10" is sorted before "2") with dask
+        # worker ID. This outsources the rank assignment to dask and prevents
+        # non-deterministic issue.
         self.args.append(
-            ("DMLC_TASK_ID=[xgboost.dask]:" + str(worker.address)).encode()
+            (f"DMLC_TASK_ID=[xgboost.dask-{wid}]:" + str(worker.address)).encode()
         )

     def __enter__(self) -> None:
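
A small sketch (with invented worker addresses) of the DMLC_TASK_ID values the changed __init__ above would produce, and of the string-ordering caveat mentioned in its comment:

# Hypothetical task IDs for a three-worker dask cluster.
task_ids = [
    "[xgboost.dask-2]:tcp://127.0.0.1:40003",
    "[xgboost.dask-0]:tcp://127.0.0.1:40001",
    "[xgboost.dask-1]:tcp://127.0.0.1:40002",
]
# The tracker sorts these strings, so rank order follows the dask worker id.
# With ten or more workers "dask-10" sorts before "dask-2", so the assignment
# stays deterministic but is not numerically identical to the worker id.
print(sorted(task_ids))
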
@@ -870,6 +880,8 @@ async def _get_rabit_args(
     except Exception: # pylint: disable=broad-except
         sched_addr = None

+    # make sure all workers are online so that we can obtain reliable scheduler_info
+    client.wait_for_workers(n_workers)
     env = await client.run_on_scheduler(
         _start_tracker, n_workers, sched_addr, user_addr
     )
@@ -32,15 +32,15 @@ class ExSocket:
             chunk = self.sock.recv(min(nbytes - nread, 1024))
             nread += len(chunk)
             res.append(chunk)
-        return b''.join(res)
+        return b"".join(res)

     def recvint(self) -> int:
         """Receive an integer of 32 bytes"""
-        return struct.unpack('@i', self.recvall(4))[0]
+        return struct.unpack("@i", self.recvall(4))[0]

     def sendint(self, value: int) -> None:
         """Send an integer of 32 bytes"""
-        self.sock.sendall(struct.pack('@i', value))
+        self.sock.sendall(struct.pack("@i", value))

     def sendstr(self, value: str) -> None:
         """Send a Python string"""
@@ -69,6 +69,7 @@ def get_family(addr: str) -> int:

 class WorkerEntry:
     """Hanlder to each worker."""
+
     def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
         worker = ExSocket(sock)
         self.sock = worker
@@ -78,7 +79,7 @@ class WorkerEntry:
         worker.sendint(MAGIC_NUM)
         self.rank = worker.recvint()
         self.world_size = worker.recvint()
-        self.jobid = worker.recvstr()
+        self.task_id = worker.recvstr()
         self.cmd = worker.recvstr()
         self.wait_accept = 0
         self.port: Optional[int] = None
@@ -96,8 +97,8 @@ class WorkerEntry:
         """Get the rank of current entry."""
         if self.rank >= 0:
             return self.rank
-        if self.jobid != 'NULL' and self.jobid in job_map:
-            return job_map[self.jobid]
+        if self.task_id != "NULL" and self.task_id in job_map:
+            return job_map[self.task_id]
         return -1

     def assign_rank(
@@ -180,7 +181,12 @@ class RabitTracker:
     """

     def __init__(
-        self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False
+        self,
+        host_ip: str,
+        n_workers: int,
+        port: int = 0,
+        use_logger: bool = False,
+        sortby: str = "host",
     ) -> None:
         """A Python implementation of RABIT tracker.

@@ -190,6 +196,13 @@ class RabitTracker:
             Use logging.info for tracker print command. When set to False, Python print
             function is used instead.

+        sortby:
+            How to sort the workers for rank assignment. The default is host, but users
+            can set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
+            deterministic rank assignment. Available options are:
+              - host
+              - task
+
         """
         sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
         sock.bind((host_ip, port))
@@ -200,6 +213,7 @@ class RabitTracker:
         self.thread: Optional[Thread] = None
         self.n_workers = n_workers
         self._use_logger = use_logger
+        self._sortby = sortby
         logging.info("start listen on %s:%d", host_ip, self.port)

     def __del__(self) -> None:
@@ -223,7 +237,7 @@ class RabitTracker:
         get environment variables for workers
         can be passed in as args or envs
         """
-        return {'DMLC_TRACKER_URI': self.host_ip, 'DMLC_TRACKER_PORT': self.port}
+        return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}

     def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
         tree_map: _TreeMap = {}
@@ -296,8 +310,16 @@ class RabitTracker:
             parent_map_[rmap[k]] = -1
         return tree_map_, parent_map_, ring_map_

+    def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]:
+        if self._sortby == "host":
+            pending.sort(key=lambda s: s.host)
+        elif self._sortby == "task":
+            pending.sort(key=lambda s: s.task_id)
+        return pending
+
     def accept_workers(self, n_workers: int) -> None:
         """Wait for all workers to connect to the tracker."""
+
         # set of nodes that finishes the job
         shutdown: Dict[int, WorkerEntry] = {}
         # set of nodes that is waiting for connections
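
To illustrate the two sortby modes handled by _sort_pending above, here is a toy comparison that stands in a namedtuple for WorkerEntry; the hosts and task IDs are invented:

from collections import namedtuple

# Stand-in carrying only the fields _sort_pending inspects.
Entry = namedtuple("Entry", ["host", "task_id"])
pending = [
    Entry(host="10.0.0.2", task_id="[xgboost.dask-1]:tcp://10.0.0.2:34567"),
    Entry(host="10.0.0.1", task_id="[xgboost.dask-0]:tcp://10.0.0.1:34568"),
]

# sortby="host": ordering (and hence rank) depends on host names, which can
# collide when several workers share a machine.
print(sorted(pending, key=lambda s: s.host))
# sortby="task": ordering follows the caller-provided DMLC_TASK_ID, which the
# dask integration derives from the dask worker id, giving deterministic ranks.
print(sorted(pending, key=lambda s: s.task_id))
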
@@ -341,27 +363,32 @@ class RabitTracker:
                 assert todo_nodes
                 pending.append(s)
                 if len(pending) == len(todo_nodes):
-                    pending.sort(key=lambda x: x.host)
+                    pending = self._sort_pending(pending)
                     for s in pending:
                         rank = todo_nodes.pop(0)
-                        if s.jobid != 'NULL':
-                            job_map[s.jobid] = rank
+                        if s.task_id != "NULL":
+                            job_map[s.task_id] = rank
                         s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                         if s.wait_accept > 0:
                             wait_conn[rank] = s
-                        logging.debug('Received %s signal from %s; assign rank %d',
-                                      s.cmd, s.host, s.rank)
+                        logging.debug(
+                            "Received %s signal from %s; assign rank %d",
+                            s.cmd,
+                            s.host,
+                            s.rank,
+                        )
                 if not todo_nodes:
-                    logging.info('@tracker All of %d nodes getting started', n_workers)
+                    logging.info("@tracker All of %d nodes getting started", n_workers)
             else:
                 s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
-                logging.debug('Received %s signal from %d', s.cmd, s.rank)
+                logging.debug("Received %s signal from %d", s.cmd, s.rank)
                 if s.wait_accept > 0:
                     wait_conn[rank] = s
-        logging.info('@tracker All nodes finishes job')
+        logging.info("@tracker All nodes finishes job")

     def start(self, n_workers: int) -> None:
         """Strat the tracker, it will wait for `n_workers` to connect."""
+
         def run() -> None:
             self.accept_workers(n_workers)

@@ -4,6 +4,7 @@ import pytest
 import testing as tm
 import numpy as np
 import sys
+import re

 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -59,3 +60,34 @@ def test_rabit_ops():
     with LocalCluster(n_workers=n_workers) as cluster:
         with Client(cluster) as client:
             run_rabit_ops(client, n_workers)
+
+
+def test_rank_assignment() -> None:
+    from distributed import Client, LocalCluster
+    from test_with_dask import _get_client_workers
+
+    def local_test(worker_id):
+        with xgb.dask.RabitContext(args):
+            for val in args:
+                sval = val.decode("utf-8")
+                if sval.startswith("DMLC_TASK_ID"):
+                    task_id = sval
+                    break
+            matched = re.search(".*-([0-9]).*", task_id)
+            rank = xgb.rabit.get_rank()
+            # As long as the number of workers is lesser than 10, rank and worker id
+            # should be the same
+            assert rank == int(matched.group(1))
+
+    with LocalCluster(n_workers=8) as cluster:
+        with Client(cluster) as client:
+            workers = _get_client_workers(client)
+            args = client.sync(
+                xgb.dask._get_rabit_args,
+                len(workers),
+                None,
+                client,
+            )
+
+            futures = client.map(local_test, range(len(workers)), workers=workers)
+            client.gather(futures)
|
|||||||
Loading…
x
Reference in New Issue
Block a user