Add typehint to tracker. (#7338)

This commit is contained in:
Jiaming Yuan 2021-10-20 12:49:36 +08:00 committed by GitHub
parent 5ff210ed75
commit f53da412aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 94 deletions

View File

@ -93,11 +93,14 @@ mypy:
cd python-package; \ cd python-package; \
mypy ./xgboost/dask.py && \ mypy ./xgboost/dask.py && \
mypy ./xgboost/rabit.py && \ mypy ./xgboost/rabit.py && \
mypy ./xgboost/tracker.py && \
mypy ./xgboost/sklearn.py && \
mypy ../demo/guide-python/external_memory.py && \ mypy ../demo/guide-python/external_memory.py && \
mypy ../demo/guide-python/categorical.py && \
mypy ../demo/guide-python/cat_in_the_dat.py && \
mypy ../tests/python-gpu/test_gpu_with_dask.py && \ mypy ../tests/python-gpu/test_gpu_with_dask.py && \
mypy ../tests/python/test_data_iterator.py && \ mypy ../tests/python/test_data_iterator.py && \
mypy ../tests/python-gpu/test_gpu_data_iterator.py && \ mypy ../tests/python-gpu/test_gpu_data_iterator.py || exit 1; \
mypy ./xgboost/sklearn.py || exit 1; \
mypy . || true ; mypy . || true ;
clean: clean:

View File

@ -138,10 +138,10 @@ def _multi_lock() -> Any:
def _start_tracker(n_workers: int) -> Dict[str, Any]: def _start_tracker(n_workers: int) -> Dict[str, Any]:
"""Start Rabit tracker """ """Start Rabit tracker """
env = {'DMLC_NUM_WORKER': n_workers} env: Dict[str, Union[int, str]] = {'DMLC_NUM_WORKER': n_workers}
host = get_host_ip('auto') host = get_host_ip('auto')
rabit_context = RabitTracker(hostIP=host, nslave=n_workers, use_logger=False) rabit_context = RabitTracker(hostIP=host, n_workers=n_workers, use_logger=False)
env.update(rabit_context.slave_envs()) env.update(rabit_context.worker_envs())
rabit_context.start(n_workers) rabit_context.start(n_workers)
thread = Thread(target=rabit_context.join) thread = Thread(target=rabit_context.join)

View File

@ -13,16 +13,21 @@ from threading import Thread
import argparse import argparse
import sys import sys
from typing import Dict, List, Tuple, Union, Optional
class ExSocket(object): _RingMap = Dict[int, Tuple[int, int]]
_TreeMap = Dict[int, List[int]]
class ExSocket:
""" """
Extension of socket to handle recv and send of special data Extension of socket to handle recv and send of special data
""" """
def __init__(self, sock): def __init__(self, sock: socket.socket) -> None:
self.sock = sock self.sock = sock
def recvall(self, nbytes): def recvall(self, nbytes: int) -> bytes:
res = [] res = []
nread = 0 nread = 0
while nread < nbytes: while nread < nbytes:
@ -31,17 +36,17 @@ class ExSocket(object):
res.append(chunk) res.append(chunk)
return b''.join(res) return b''.join(res)
def recvint(self): def recvint(self) -> int:
return struct.unpack('@i', self.recvall(4))[0] return struct.unpack('@i', self.recvall(4))[0]
def sendint(self, n): def sendint(self, n: int) -> None:
self.sock.sendall(struct.pack('@i', n)) self.sock.sendall(struct.pack('@i', n))
def sendstr(self, s): def sendstr(self, s: str) -> None:
self.sendint(len(s)) self.sendint(len(s))
self.sock.sendall(s.encode()) self.sock.sendall(s.encode())
def recvstr(self): def recvstr(self) -> str:
slen = self.recvint() slen = self.recvint()
return self.recvall(slen).decode() return self.recvall(slen).decode()
@ -50,37 +55,44 @@ class ExSocket(object):
kMagic = 0xff99 kMagic = 0xff99
def get_some_ip(host): def get_some_ip(host: str) -> str:
return socket.getaddrinfo(host, None)[0][4][0] return socket.getaddrinfo(host, None)[0][4][0]
def get_family(addr): def get_family(addr: str) -> int:
return socket.getaddrinfo(addr, None)[0][0] return socket.getaddrinfo(addr, None)[0][0]
class SlaveEntry(object): class WorkerEntry:
def __init__(self, sock, s_addr): def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
slave = ExSocket(sock) worker = ExSocket(sock)
self.sock = slave self.sock = worker
self.host = get_some_ip(s_addr[0]) self.host = get_some_ip(s_addr[0])
magic = slave.recvint() magic = worker.recvint()
assert magic == kMagic, f"invalid magic number={magic} from {self.host}" assert magic == kMagic, f"invalid magic number={magic} from {self.host}"
slave.sendint(kMagic) worker.sendint(kMagic)
self.rank = slave.recvint() self.rank = worker.recvint()
self.world_size = slave.recvint() self.world_size = worker.recvint()
self.jobid = slave.recvstr() self.jobid = worker.recvstr()
self.cmd = slave.recvstr() self.cmd = worker.recvstr()
self.wait_accept = 0 self.wait_accept = 0
self.port = None self.port: Optional[int] = None
def decide_rank(self, job_map): def decide_rank(self, job_map: Dict[str, int]) -> int:
if self.rank >= 0: if self.rank >= 0:
return self.rank return self.rank
if self.jobid != 'NULL' and self.jobid in job_map: if self.jobid != 'NULL' and self.jobid in job_map:
return job_map[self.jobid] return job_map[self.jobid]
return -1 return -1
def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map): def assign_rank(
self,
rank: int,
wait_conn: Dict[int, "WorkerEntry"],
tree_map: _TreeMap,
parent_map: Dict[int, int],
ring_map: _RingMap,
) -> List[int]:
self.rank = rank self.rank = rank
nnset = set(tree_map[rank]) nnset = set(tree_map[rank])
rprev, rnext = ring_map[rank] rprev, rnext = ring_map[rank]
@ -120,7 +132,9 @@ class SlaveEntry(object):
self.sock.sendint(len(badset) - len(conset)) self.sock.sendint(len(badset) - len(conset))
for r in conset: for r in conset:
self.sock.sendstr(wait_conn[r].host) self.sock.sendstr(wait_conn[r].host)
self.sock.sendint(wait_conn[r].port) port = wait_conn[r].port
assert port is not None
self.sock.sendint(port)
self.sock.sendint(r) self.sock.sendint(r)
nerr = self.sock.recvint() nerr = self.sock.recvint()
if nerr != 0: if nerr != 0:
@ -138,13 +152,17 @@ class SlaveEntry(object):
return rmset return rmset
class RabitTracker(object): class RabitTracker:
""" """
tracker for rabit tracker for rabit
""" """
def __init__( def __init__(
self, hostIP, nslave, port=9091, port_end=9999, use_logger: bool = True self, hostIP: str,
n_workers: int,
port: int = 9091,
port_end: int = 9999,
use_logger: bool = True,
) -> None: ) -> None:
"""A Python implementation of RABIT tracker. """A Python implementation of RABIT tracker.
@ -168,45 +186,45 @@ class RabitTracker(object):
sock.listen(256) sock.listen(256)
self.sock = sock self.sock = sock
self.hostIP = hostIP self.hostIP = hostIP
self.thread = None self.thread: Optional[Thread] = None
self.start_time = None self.n_workers = n_workers
self.end_time = None
self.nslave = nslave
self._use_logger = use_logger self._use_logger = use_logger
logging.info('start listen on %s:%d', hostIP, self.port) logging.info('start listen on %s:%d', hostIP, self.port)
def __del__(self): def __del__(self) -> None:
self.sock.close() self.sock.close()
@staticmethod @staticmethod
def get_neighbor(rank, nslave): def get_neighbor(rank: int, n_workers: int) -> List[int]:
rank = rank + 1 rank = rank + 1
ret = [] ret = []
if rank > 1: if rank > 1:
ret.append(rank // 2 - 1) ret.append(rank // 2 - 1)
if rank * 2 - 1 < nslave: if rank * 2 - 1 < n_workers:
ret.append(rank * 2 - 1) ret.append(rank * 2 - 1)
if rank * 2 < nslave: if rank * 2 < n_workers:
ret.append(rank * 2) ret.append(rank * 2)
return ret return ret
def slave_envs(self): def worker_envs(self) -> Dict[str, Union[str, int]]:
""" """
get enviroment variables for slaves get enviroment variables for workers
can be passed in as args or envs can be passed in as args or envs
""" """
return {'DMLC_TRACKER_URI': self.hostIP, return {'DMLC_TRACKER_URI': self.hostIP,
'DMLC_TRACKER_PORT': self.port} 'DMLC_TRACKER_PORT': self.port}
def get_tree(self, nslave): def get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
tree_map = {} tree_map: _TreeMap = {}
parent_map = {} parent_map: Dict[int, int] = {}
for r in range(nslave): for r in range(n_workers):
tree_map[r] = self.get_neighbor(r, nslave) tree_map[r] = self.get_neighbor(r, n_workers)
parent_map[r] = (r + 1) // 2 - 1 parent_map[r] = (r + 1) // 2 - 1
return tree_map, parent_map return tree_map, parent_map
def find_share_ring(self, tree_map, parent_map, r): def find_share_ring(
self, tree_map: _TreeMap, parent_map: Dict[int, int], r: int
) -> List[int]:
""" """
get a ring structure that tends to share nodes with the tree get a ring structure that tends to share nodes with the tree
return a list starting from r return a list starting from r
@ -225,63 +243,65 @@ class RabitTracker(object):
rlst += vlst rlst += vlst
return rlst return rlst
def get_ring(self, tree_map, parent_map): def get_ring(self, tree_map: _TreeMap, parent_map: Dict[int, int]) -> _RingMap:
""" """
get a ring connection used to recover local data get a ring connection used to recover local data
""" """
assert parent_map[0] == -1 assert parent_map[0] == -1
rlst = self.find_share_ring(tree_map, parent_map, 0) rlst = self.find_share_ring(tree_map, parent_map, 0)
assert len(rlst) == len(tree_map) assert len(rlst) == len(tree_map)
ring_map = {} ring_map: _RingMap = {}
nslave = len(tree_map) n_workers = len(tree_map)
for r in range(nslave): for r in range(n_workers):
rprev = (r + nslave - 1) % nslave rprev = (r + n_workers - 1) % n_workers
rnext = (r + 1) % nslave rnext = (r + 1) % n_workers
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
return ring_map return ring_map
def get_link_map(self, nslave): def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]:
""" """
get the link map, this is a bit hacky, call for better algorithm get the link map, this is a bit hacky, call for better algorithm
to place similar nodes together to place similar nodes together
""" """
tree_map, parent_map = self.get_tree(nslave) tree_map, parent_map = self.get_tree(n_workers)
ring_map = self.get_ring(tree_map, parent_map) ring_map = self.get_ring(tree_map, parent_map)
rmap = {0: 0} rmap = {0: 0}
k = 0 k = 0
for i in range(nslave - 1): for i in range(n_workers - 1):
k = ring_map[k][1] k = ring_map[k][1]
rmap[k] = i + 1 rmap[k] = i + 1
ring_map_ = {} ring_map_: _RingMap = {}
tree_map_ = {} tree_map_: _TreeMap = {}
parent_map_ = {} parent_map_: Dict[int, int] = {}
for k, v in ring_map.items(): for k, v in ring_map.items():
ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]]) ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
for k, v in tree_map.items(): for k, tree_nodes in tree_map.items():
tree_map_[rmap[k]] = [rmap[x] for x in v] tree_map_[rmap[k]] = [rmap[x] for x in tree_nodes]
for k, v in parent_map.items(): for k, parent in parent_map.items():
if k != 0: if k != 0:
parent_map_[rmap[k]] = rmap[v] parent_map_[rmap[k]] = rmap[parent]
else: else:
parent_map_[rmap[k]] = -1 parent_map_[rmap[k]] = -1
return tree_map_, parent_map_, ring_map_ return tree_map_, parent_map_, ring_map_
def accept_slaves(self, nslave): def accept_workers(self, n_workers: int) -> None:
# set of nodes that finishs the job # set of nodes that finishs the job
shutdown = {} shutdown: Dict[int, WorkerEntry] = {}
# set of nodes that is waiting for connections # set of nodes that is waiting for connections
wait_conn = {} wait_conn: Dict[int, WorkerEntry] = {}
# maps job id to rank # maps job id to rank
job_map = {} job_map: Dict[str, int] = {}
# list of workers that is pending to be assigned rank # list of workers that is pending to be assigned rank
pending = [] pending: List[WorkerEntry] = []
# lazy initialize tree_map # lazy initialize tree_map
tree_map = None tree_map = None
while len(shutdown) != nslave: start_time = time.time()
while len(shutdown) != n_workers:
fd, s_addr = self.sock.accept() fd, s_addr = self.sock.accept()
s = SlaveEntry(fd, s_addr) s = WorkerEntry(fd, s_addr)
if s.cmd == 'print': if s.cmd == 'print':
msg = s.sock.recvstr() msg = s.sock.recvstr()
# On dask we use print to avoid setting global verbosity. # On dask we use print to avoid setting global verbosity.
@ -297,16 +317,16 @@ class RabitTracker(object):
logging.debug('Received %s signal from %d', s.cmd, s.rank) logging.debug('Received %s signal from %d', s.cmd, s.rank)
continue continue
assert s.cmd in ("start", "recover") assert s.cmd in ("start", "recover")
# lazily initialize the slaves # lazily initialize the workers
if tree_map is None: if tree_map is None:
assert s.cmd == 'start' assert s.cmd == 'start'
if s.world_size > 0: if s.world_size > 0:
nslave = s.world_size n_workers = s.world_size
tree_map, parent_map, ring_map = self.get_link_map(nslave) tree_map, parent_map, ring_map = self.get_link_map(n_workers)
# set of nodes that is pending for getting up # set of nodes that is pending for getting up
todo_nodes = list(range(nslave)) todo_nodes = list(range(n_workers))
else: else:
assert s.world_size in (-1, nslave) assert s.world_size in (-1, n_workers)
if s.cmd == 'recover': if s.cmd == 'recover':
assert s.rank >= 0 assert s.rank >= 0
@ -327,34 +347,35 @@ class RabitTracker(object):
logging.debug('Received %s signal from %s; assign rank %d', logging.debug('Received %s signal from %s; assign rank %d',
s.cmd, s.host, s.rank) s.cmd, s.host, s.rank)
if not todo_nodes: if not todo_nodes:
logging.info('@tracker All of %d nodes getting started', nslave) logging.info('@tracker All of %d nodes getting started', n_workers)
self.start_time = time.time()
else: else:
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
logging.debug('Received %s signal from %d', s.cmd, s.rank) logging.debug('Received %s signal from %d', s.cmd, s.rank)
if s.wait_accept > 0: if s.wait_accept > 0:
wait_conn[rank] = s wait_conn[rank] = s
logging.info('@tracker All nodes finishes job') logging.info('@tracker All nodes finishes job')
self.end_time = time.time() end_time = time.time()
logging.info('@tracker %s secs between node start and job finish', logging.info(
str(self.end_time - self.start_time)) '@tracker %s secs between node start and job finish',
str(end_time - start_time)
)
def start(self, nslave): def start(self, n_workers: int) -> None:
def run(): def run() -> None:
self.accept_slaves(nslave) self.accept_workers(n_workers)
self.thread = Thread(target=run, args=(), daemon=True) self.thread = Thread(target=run, args=(), daemon=True)
self.thread.start() self.thread.start()
def join(self): def join(self) -> None:
while self.thread.is_alive(): while self.thread is not None and self.thread.is_alive():
self.thread.join(100) self.thread.join(100)
def alive(self): def alive(self) -> bool:
return self.thread.is_alive() return self.thread is not None and self.thread.is_alive()
def get_host_ip(hostIP=None): def get_host_ip(hostIP: Optional[str] = None) -> str:
if hostIP is None or hostIP == 'auto': if hostIP is None or hostIP == 'auto':
hostIP = 'ip' hostIP = 'ip'
@ -374,10 +395,12 @@ def get_host_ip(hostIP=None):
# doesn't have to be reachable # doesn't have to be reachable
s.connect(('10.255.255.255', 1)) s.connect(('10.255.255.255', 1))
hostIP = s.getsockname()[0] hostIP = s.getsockname()[0]
assert hostIP is not None
return hostIP return hostIP
def start_rabit_tracker(args): def start_rabit_tracker(args: argparse.Namespace) -> None:
"""Standalone function to start rabit tracker. """Standalone function to start rabit tracker.
Parameters Parameters
@ -386,8 +409,8 @@ def start_rabit_tracker(args):
""" """
envs = {'DMLC_NUM_WORKER': args.num_workers, envs = {'DMLC_NUM_WORKER': args.num_workers,
'DMLC_NUM_SERVER': args.num_servers} 'DMLC_NUM_SERVER': args.num_servers}
rabit = RabitTracker(hostIP=get_host_ip(args.host_ip), nslave=args.num_workers) rabit = RabitTracker(hostIP=get_host_ip(args.host_ip), n_workers=args.num_workers)
envs.update(rabit.slave_envs()) envs.update(rabit.worker_envs())
rabit.start(args.num_workers) rabit.start(args.num_workers)
sys.stdout.write('DMLC_TRACKER_ENV_START\n') sys.stdout.write('DMLC_TRACKER_ENV_START\n')
# simply write configuration to stdout # simply write configuration to stdout
@ -398,13 +421,15 @@ def start_rabit_tracker(args):
rabit.join() rabit.join()
def main(): def main() -> None:
"""Main function if tracker is executed in standalone mode.""" """Main function if tracker is executed in standalone mode."""
parser = argparse.ArgumentParser(description='Rabit Tracker start.') parser = argparse.ArgumentParser(description='Rabit Tracker start.')
parser.add_argument('--num-workers', required=True, type=int, parser.add_argument('--num-workers', required=True, type=int,
help='Number of worker proccess to be launched.') help='Number of worker proccess to be launched.')
parser.add_argument('--num-servers', default=0, type=int, parser.add_argument(
help='Number of server process to be launched. Only used in PS jobs.') '--num-servers', default=0, type=int,
help='Number of server process to be launched. Only used in PS jobs.'
)
parser.add_argument('--host-ip', default=None, type=str, parser.add_argument('--host-ip', default=None, type=str,
help=('Host IP addressed, this is only needed ' + help=('Host IP addressed, this is only needed ' +
'if the host IP cannot be automatically guessed.')) 'if the host IP cannot be automatically guessed.'))

View File

@ -10,7 +10,7 @@ if sys.platform.startswith("win"):
def test_rabit_tracker(): def test_rabit_tracker():
tracker = RabitTracker(hostIP='127.0.0.1', nslave=1) tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1)
tracker.start(1) tracker.start(1)
rabit_env = [ rabit_env = [
str.encode('DMLC_TRACKER_URI=127.0.0.1'), str.encode('DMLC_TRACKER_URI=127.0.0.1'),