From f53da412aa64ca5caa21c2f39d9a7691a1a0835f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 20 Oct 2021 12:49:36 +0800 Subject: [PATCH] Add typehint to tracker. (#7338) --- Makefile | 7 +- python-package/xgboost/dask.py | 6 +- python-package/xgboost/tracker.py | 201 +++++++++++++++++------------- tests/python/test_tracker.py | 2 +- 4 files changed, 122 insertions(+), 94 deletions(-) diff --git a/Makefile b/Makefile index d65cb56cc..1d86c2ed0 100644 --- a/Makefile +++ b/Makefile @@ -93,11 +93,14 @@ mypy: cd python-package; \ mypy ./xgboost/dask.py && \ mypy ./xgboost/rabit.py && \ + mypy ./xgboost/tracker.py && \ + mypy ./xgboost/sklearn.py && \ mypy ../demo/guide-python/external_memory.py && \ + mypy ../demo/guide-python/categorical.py && \ + mypy ../demo/guide-python/cat_in_the_dat.py && \ mypy ../tests/python-gpu/test_gpu_with_dask.py && \ mypy ../tests/python/test_data_iterator.py && \ - mypy ../tests/python-gpu/test_gpu_data_iterator.py && \ - mypy ./xgboost/sklearn.py || exit 1; \ + mypy ../tests/python-gpu/test_gpu_data_iterator.py || exit 1; \ mypy . || true ; clean: diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 5fa37dedd..b089561f4 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -138,10 +138,10 @@ def _multi_lock() -> Any: def _start_tracker(n_workers: int) -> Dict[str, Any]: """Start Rabit tracker """ - env = {'DMLC_NUM_WORKER': n_workers} + env: Dict[str, Union[int, str]] = {'DMLC_NUM_WORKER': n_workers} host = get_host_ip('auto') - rabit_context = RabitTracker(hostIP=host, nslave=n_workers, use_logger=False) - env.update(rabit_context.slave_envs()) + rabit_context = RabitTracker(hostIP=host, n_workers=n_workers, use_logger=False) + env.update(rabit_context.worker_envs()) rabit_context.start(n_workers) thread = Thread(target=rabit_context.join) diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py index c64e04860..61e3a1a06 100644 --- a/python-package/xgboost/tracker.py +++ b/python-package/xgboost/tracker.py @@ -13,16 +13,21 @@ from threading import Thread import argparse import sys +from typing import Dict, List, Tuple, Union, Optional -class ExSocket(object): +_RingMap = Dict[int, Tuple[int, int]] +_TreeMap = Dict[int, List[int]] + + +class ExSocket: """ Extension of socket to handle recv and send of special data """ - def __init__(self, sock): + def __init__(self, sock: socket.socket) -> None: self.sock = sock - def recvall(self, nbytes): + def recvall(self, nbytes: int) -> bytes: res = [] nread = 0 while nread < nbytes: @@ -31,17 +36,17 @@ class ExSocket(object): res.append(chunk) return b''.join(res) - def recvint(self): + def recvint(self) -> int: return struct.unpack('@i', self.recvall(4))[0] - def sendint(self, n): + def sendint(self, n: int) -> None: self.sock.sendall(struct.pack('@i', n)) - def sendstr(self, s): + def sendstr(self, s: str) -> None: self.sendint(len(s)) self.sock.sendall(s.encode()) - def recvstr(self): + def recvstr(self) -> str: slen = self.recvint() return self.recvall(slen).decode() @@ -50,37 +55,44 @@ class ExSocket(object): kMagic = 0xff99 -def get_some_ip(host): +def get_some_ip(host: str) -> str: return socket.getaddrinfo(host, None)[0][4][0] -def get_family(addr): +def get_family(addr: str) -> int: return socket.getaddrinfo(addr, None)[0][0] -class SlaveEntry(object): - def __init__(self, sock, s_addr): - slave = ExSocket(sock) - self.sock = slave +class WorkerEntry: + def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]): + worker = ExSocket(sock) + self.sock = worker self.host = get_some_ip(s_addr[0]) - magic = slave.recvint() + magic = worker.recvint() assert magic == kMagic, f"invalid magic number={magic} from {self.host}" - slave.sendint(kMagic) - self.rank = slave.recvint() - self.world_size = slave.recvint() - self.jobid = slave.recvstr() - self.cmd = slave.recvstr() + worker.sendint(kMagic) + self.rank = worker.recvint() + self.world_size = worker.recvint() + self.jobid = worker.recvstr() + self.cmd = worker.recvstr() self.wait_accept = 0 - self.port = None + self.port: Optional[int] = None - def decide_rank(self, job_map): + def decide_rank(self, job_map: Dict[str, int]) -> int: if self.rank >= 0: return self.rank if self.jobid != 'NULL' and self.jobid in job_map: return job_map[self.jobid] return -1 - def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map): + def assign_rank( + self, + rank: int, + wait_conn: Dict[int, "WorkerEntry"], + tree_map: _TreeMap, + parent_map: Dict[int, int], + ring_map: _RingMap, + ) -> List[int]: self.rank = rank nnset = set(tree_map[rank]) rprev, rnext = ring_map[rank] @@ -120,7 +132,9 @@ class SlaveEntry(object): self.sock.sendint(len(badset) - len(conset)) for r in conset: self.sock.sendstr(wait_conn[r].host) - self.sock.sendint(wait_conn[r].port) + port = wait_conn[r].port + assert port is not None + self.sock.sendint(port) self.sock.sendint(r) nerr = self.sock.recvint() if nerr != 0: @@ -138,13 +152,17 @@ class SlaveEntry(object): return rmset -class RabitTracker(object): +class RabitTracker: """ tracker for rabit """ def __init__( - self, hostIP, nslave, port=9091, port_end=9999, use_logger: bool = True + self, hostIP: str, + n_workers: int, + port: int = 9091, + port_end: int = 9999, + use_logger: bool = True, ) -> None: """A Python implementation of RABIT tracker. @@ -168,45 +186,45 @@ class RabitTracker(object): sock.listen(256) self.sock = sock self.hostIP = hostIP - self.thread = None - self.start_time = None - self.end_time = None - self.nslave = nslave + self.thread: Optional[Thread] = None + self.n_workers = n_workers self._use_logger = use_logger logging.info('start listen on %s:%d', hostIP, self.port) - def __del__(self): + def __del__(self) -> None: self.sock.close() @staticmethod - def get_neighbor(rank, nslave): + def get_neighbor(rank: int, n_workers: int) -> List[int]: rank = rank + 1 ret = [] if rank > 1: ret.append(rank // 2 - 1) - if rank * 2 - 1 < nslave: + if rank * 2 - 1 < n_workers: ret.append(rank * 2 - 1) - if rank * 2 < nslave: + if rank * 2 < n_workers: ret.append(rank * 2) return ret - def slave_envs(self): + def worker_envs(self) -> Dict[str, Union[str, int]]: """ - get enviroment variables for slaves + get enviroment variables for workers can be passed in as args or envs """ return {'DMLC_TRACKER_URI': self.hostIP, 'DMLC_TRACKER_PORT': self.port} - def get_tree(self, nslave): - tree_map = {} - parent_map = {} - for r in range(nslave): - tree_map[r] = self.get_neighbor(r, nslave) + def get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]: + tree_map: _TreeMap = {} + parent_map: Dict[int, int] = {} + for r in range(n_workers): + tree_map[r] = self.get_neighbor(r, n_workers) parent_map[r] = (r + 1) // 2 - 1 return tree_map, parent_map - def find_share_ring(self, tree_map, parent_map, r): + def find_share_ring( + self, tree_map: _TreeMap, parent_map: Dict[int, int], r: int + ) -> List[int]: """ get a ring structure that tends to share nodes with the tree return a list starting from r @@ -225,63 +243,65 @@ class RabitTracker(object): rlst += vlst return rlst - def get_ring(self, tree_map, parent_map): + def get_ring(self, tree_map: _TreeMap, parent_map: Dict[int, int]) -> _RingMap: """ get a ring connection used to recover local data """ assert parent_map[0] == -1 rlst = self.find_share_ring(tree_map, parent_map, 0) assert len(rlst) == len(tree_map) - ring_map = {} - nslave = len(tree_map) - for r in range(nslave): - rprev = (r + nslave - 1) % nslave - rnext = (r + 1) % nslave + ring_map: _RingMap = {} + n_workers = len(tree_map) + for r in range(n_workers): + rprev = (r + n_workers - 1) % n_workers + rnext = (r + 1) % n_workers ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) return ring_map - def get_link_map(self, nslave): + def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]: """ get the link map, this is a bit hacky, call for better algorithm to place similar nodes together """ - tree_map, parent_map = self.get_tree(nslave) + tree_map, parent_map = self.get_tree(n_workers) ring_map = self.get_ring(tree_map, parent_map) rmap = {0: 0} k = 0 - for i in range(nslave - 1): + for i in range(n_workers - 1): k = ring_map[k][1] rmap[k] = i + 1 - ring_map_ = {} - tree_map_ = {} - parent_map_ = {} + ring_map_: _RingMap = {} + tree_map_: _TreeMap = {} + parent_map_: Dict[int, int] = {} for k, v in ring_map.items(): ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]]) - for k, v in tree_map.items(): - tree_map_[rmap[k]] = [rmap[x] for x in v] - for k, v in parent_map.items(): + for k, tree_nodes in tree_map.items(): + tree_map_[rmap[k]] = [rmap[x] for x in tree_nodes] + for k, parent in parent_map.items(): if k != 0: - parent_map_[rmap[k]] = rmap[v] + parent_map_[rmap[k]] = rmap[parent] else: parent_map_[rmap[k]] = -1 return tree_map_, parent_map_, ring_map_ - def accept_slaves(self, nslave): + def accept_workers(self, n_workers: int) -> None: # set of nodes that finishs the job - shutdown = {} + shutdown: Dict[int, WorkerEntry] = {} # set of nodes that is waiting for connections - wait_conn = {} + wait_conn: Dict[int, WorkerEntry] = {} # maps job id to rank - job_map = {} + job_map: Dict[str, int] = {} # list of workers that is pending to be assigned rank - pending = [] + pending: List[WorkerEntry] = [] # lazy initialize tree_map tree_map = None - while len(shutdown) != nslave: + start_time = time.time() + + while len(shutdown) != n_workers: fd, s_addr = self.sock.accept() - s = SlaveEntry(fd, s_addr) + s = WorkerEntry(fd, s_addr) if s.cmd == 'print': msg = s.sock.recvstr() # On dask we use print to avoid setting global verbosity. @@ -297,16 +317,16 @@ class RabitTracker(object): logging.debug('Received %s signal from %d', s.cmd, s.rank) continue assert s.cmd in ("start", "recover") - # lazily initialize the slaves + # lazily initialize the workers if tree_map is None: assert s.cmd == 'start' if s.world_size > 0: - nslave = s.world_size - tree_map, parent_map, ring_map = self.get_link_map(nslave) + n_workers = s.world_size + tree_map, parent_map, ring_map = self.get_link_map(n_workers) # set of nodes that is pending for getting up - todo_nodes = list(range(nslave)) + todo_nodes = list(range(n_workers)) else: - assert s.world_size in (-1, nslave) + assert s.world_size in (-1, n_workers) if s.cmd == 'recover': assert s.rank >= 0 @@ -327,34 +347,35 @@ class RabitTracker(object): logging.debug('Received %s signal from %s; assign rank %d', s.cmd, s.host, s.rank) if not todo_nodes: - logging.info('@tracker All of %d nodes getting started', nslave) - self.start_time = time.time() + logging.info('@tracker All of %d nodes getting started', n_workers) else: s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) logging.debug('Received %s signal from %d', s.cmd, s.rank) if s.wait_accept > 0: wait_conn[rank] = s logging.info('@tracker All nodes finishes job') - self.end_time = time.time() - logging.info('@tracker %s secs between node start and job finish', - str(self.end_time - self.start_time)) + end_time = time.time() + logging.info( + '@tracker %s secs between node start and job finish', + str(end_time - start_time) + ) - def start(self, nslave): - def run(): - self.accept_slaves(nslave) + def start(self, n_workers: int) -> None: + def run() -> None: + self.accept_workers(n_workers) self.thread = Thread(target=run, args=(), daemon=True) self.thread.start() - def join(self): - while self.thread.is_alive(): + def join(self) -> None: + while self.thread is not None and self.thread.is_alive(): self.thread.join(100) - def alive(self): - return self.thread.is_alive() + def alive(self) -> bool: + return self.thread is not None and self.thread.is_alive() -def get_host_ip(hostIP=None): +def get_host_ip(hostIP: Optional[str] = None) -> str: if hostIP is None or hostIP == 'auto': hostIP = 'ip' @@ -374,10 +395,12 @@ def get_host_ip(hostIP=None): # doesn't have to be reachable s.connect(('10.255.255.255', 1)) hostIP = s.getsockname()[0] + + assert hostIP is not None return hostIP -def start_rabit_tracker(args): +def start_rabit_tracker(args: argparse.Namespace) -> None: """Standalone function to start rabit tracker. Parameters @@ -386,8 +409,8 @@ def start_rabit_tracker(args): """ envs = {'DMLC_NUM_WORKER': args.num_workers, 'DMLC_NUM_SERVER': args.num_servers} - rabit = RabitTracker(hostIP=get_host_ip(args.host_ip), nslave=args.num_workers) - envs.update(rabit.slave_envs()) + rabit = RabitTracker(hostIP=get_host_ip(args.host_ip), n_workers=args.num_workers) + envs.update(rabit.worker_envs()) rabit.start(args.num_workers) sys.stdout.write('DMLC_TRACKER_ENV_START\n') # simply write configuration to stdout @@ -398,13 +421,15 @@ def start_rabit_tracker(args): rabit.join() -def main(): +def main() -> None: """Main function if tracker is executed in standalone mode.""" parser = argparse.ArgumentParser(description='Rabit Tracker start.') parser.add_argument('--num-workers', required=True, type=int, help='Number of worker proccess to be launched.') - parser.add_argument('--num-servers', default=0, type=int, - help='Number of server process to be launched. Only used in PS jobs.') + parser.add_argument( + '--num-servers', default=0, type=int, + help='Number of server process to be launched. Only used in PS jobs.' + ) parser.add_argument('--host-ip', default=None, type=str, help=('Host IP addressed, this is only needed ' + 'if the host IP cannot be automatically guessed.')) diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index a6490f50c..e86c7c72a 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -10,7 +10,7 @@ if sys.platform.startswith("win"): def test_rabit_tracker(): - tracker = RabitTracker(hostIP='127.0.0.1', nslave=1) + tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1) tracker.start(1) rabit_env = [ str.encode('DMLC_TRACKER_URI=127.0.0.1'),