Add typehint to tracker. (#7338)
This commit is contained in:
parent
5ff210ed75
commit
f53da412aa
7
Makefile
7
Makefile
@ -93,11 +93,14 @@ mypy:
|
|||||||
cd python-package; \
|
cd python-package; \
|
||||||
mypy ./xgboost/dask.py && \
|
mypy ./xgboost/dask.py && \
|
||||||
mypy ./xgboost/rabit.py && \
|
mypy ./xgboost/rabit.py && \
|
||||||
|
mypy ./xgboost/tracker.py && \
|
||||||
|
mypy ./xgboost/sklearn.py && \
|
||||||
mypy ../demo/guide-python/external_memory.py && \
|
mypy ../demo/guide-python/external_memory.py && \
|
||||||
|
mypy ../demo/guide-python/categorical.py && \
|
||||||
|
mypy ../demo/guide-python/cat_in_the_dat.py && \
|
||||||
mypy ../tests/python-gpu/test_gpu_with_dask.py && \
|
mypy ../tests/python-gpu/test_gpu_with_dask.py && \
|
||||||
mypy ../tests/python/test_data_iterator.py && \
|
mypy ../tests/python/test_data_iterator.py && \
|
||||||
mypy ../tests/python-gpu/test_gpu_data_iterator.py && \
|
mypy ../tests/python-gpu/test_gpu_data_iterator.py || exit 1; \
|
||||||
mypy ./xgboost/sklearn.py || exit 1; \
|
|
||||||
mypy . || true ;
|
mypy . || true ;
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
|
|||||||
@ -138,10 +138,10 @@ def _multi_lock() -> Any:
|
|||||||
|
|
||||||
def _start_tracker(n_workers: int) -> Dict[str, Any]:
|
def _start_tracker(n_workers: int) -> Dict[str, Any]:
|
||||||
"""Start Rabit tracker """
|
"""Start Rabit tracker """
|
||||||
env = {'DMLC_NUM_WORKER': n_workers}
|
env: Dict[str, Union[int, str]] = {'DMLC_NUM_WORKER': n_workers}
|
||||||
host = get_host_ip('auto')
|
host = get_host_ip('auto')
|
||||||
rabit_context = RabitTracker(hostIP=host, nslave=n_workers, use_logger=False)
|
rabit_context = RabitTracker(hostIP=host, n_workers=n_workers, use_logger=False)
|
||||||
env.update(rabit_context.slave_envs())
|
env.update(rabit_context.worker_envs())
|
||||||
|
|
||||||
rabit_context.start(n_workers)
|
rabit_context.start(n_workers)
|
||||||
thread = Thread(target=rabit_context.join)
|
thread = Thread(target=rabit_context.join)
|
||||||
|
|||||||
@ -13,16 +13,21 @@ from threading import Thread
|
|||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from typing import Dict, List, Tuple, Union, Optional
|
||||||
|
|
||||||
class ExSocket(object):
|
_RingMap = Dict[int, Tuple[int, int]]
|
||||||
|
_TreeMap = Dict[int, List[int]]
|
||||||
|
|
||||||
|
|
||||||
|
class ExSocket:
|
||||||
"""
|
"""
|
||||||
Extension of socket to handle recv and send of special data
|
Extension of socket to handle recv and send of special data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sock):
|
def __init__(self, sock: socket.socket) -> None:
|
||||||
self.sock = sock
|
self.sock = sock
|
||||||
|
|
||||||
def recvall(self, nbytes):
|
def recvall(self, nbytes: int) -> bytes:
|
||||||
res = []
|
res = []
|
||||||
nread = 0
|
nread = 0
|
||||||
while nread < nbytes:
|
while nread < nbytes:
|
||||||
@ -31,17 +36,17 @@ class ExSocket(object):
|
|||||||
res.append(chunk)
|
res.append(chunk)
|
||||||
return b''.join(res)
|
return b''.join(res)
|
||||||
|
|
||||||
def recvint(self):
|
def recvint(self) -> int:
|
||||||
return struct.unpack('@i', self.recvall(4))[0]
|
return struct.unpack('@i', self.recvall(4))[0]
|
||||||
|
|
||||||
def sendint(self, n):
|
def sendint(self, n: int) -> None:
|
||||||
self.sock.sendall(struct.pack('@i', n))
|
self.sock.sendall(struct.pack('@i', n))
|
||||||
|
|
||||||
def sendstr(self, s):
|
def sendstr(self, s: str) -> None:
|
||||||
self.sendint(len(s))
|
self.sendint(len(s))
|
||||||
self.sock.sendall(s.encode())
|
self.sock.sendall(s.encode())
|
||||||
|
|
||||||
def recvstr(self):
|
def recvstr(self) -> str:
|
||||||
slen = self.recvint()
|
slen = self.recvint()
|
||||||
return self.recvall(slen).decode()
|
return self.recvall(slen).decode()
|
||||||
|
|
||||||
@ -50,37 +55,44 @@ class ExSocket(object):
|
|||||||
kMagic = 0xff99
|
kMagic = 0xff99
|
||||||
|
|
||||||
|
|
||||||
def get_some_ip(host):
|
def get_some_ip(host: str) -> str:
|
||||||
return socket.getaddrinfo(host, None)[0][4][0]
|
return socket.getaddrinfo(host, None)[0][4][0]
|
||||||
|
|
||||||
|
|
||||||
def get_family(addr):
|
def get_family(addr: str) -> int:
|
||||||
return socket.getaddrinfo(addr, None)[0][0]
|
return socket.getaddrinfo(addr, None)[0][0]
|
||||||
|
|
||||||
|
|
||||||
class SlaveEntry(object):
|
class WorkerEntry:
|
||||||
def __init__(self, sock, s_addr):
|
def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
|
||||||
slave = ExSocket(sock)
|
worker = ExSocket(sock)
|
||||||
self.sock = slave
|
self.sock = worker
|
||||||
self.host = get_some_ip(s_addr[0])
|
self.host = get_some_ip(s_addr[0])
|
||||||
magic = slave.recvint()
|
magic = worker.recvint()
|
||||||
assert magic == kMagic, f"invalid magic number={magic} from {self.host}"
|
assert magic == kMagic, f"invalid magic number={magic} from {self.host}"
|
||||||
slave.sendint(kMagic)
|
worker.sendint(kMagic)
|
||||||
self.rank = slave.recvint()
|
self.rank = worker.recvint()
|
||||||
self.world_size = slave.recvint()
|
self.world_size = worker.recvint()
|
||||||
self.jobid = slave.recvstr()
|
self.jobid = worker.recvstr()
|
||||||
self.cmd = slave.recvstr()
|
self.cmd = worker.recvstr()
|
||||||
self.wait_accept = 0
|
self.wait_accept = 0
|
||||||
self.port = None
|
self.port: Optional[int] = None
|
||||||
|
|
||||||
def decide_rank(self, job_map):
|
def decide_rank(self, job_map: Dict[str, int]) -> int:
|
||||||
if self.rank >= 0:
|
if self.rank >= 0:
|
||||||
return self.rank
|
return self.rank
|
||||||
if self.jobid != 'NULL' and self.jobid in job_map:
|
if self.jobid != 'NULL' and self.jobid in job_map:
|
||||||
return job_map[self.jobid]
|
return job_map[self.jobid]
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
|
def assign_rank(
|
||||||
|
self,
|
||||||
|
rank: int,
|
||||||
|
wait_conn: Dict[int, "WorkerEntry"],
|
||||||
|
tree_map: _TreeMap,
|
||||||
|
parent_map: Dict[int, int],
|
||||||
|
ring_map: _RingMap,
|
||||||
|
) -> List[int]:
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
nnset = set(tree_map[rank])
|
nnset = set(tree_map[rank])
|
||||||
rprev, rnext = ring_map[rank]
|
rprev, rnext = ring_map[rank]
|
||||||
@ -120,7 +132,9 @@ class SlaveEntry(object):
|
|||||||
self.sock.sendint(len(badset) - len(conset))
|
self.sock.sendint(len(badset) - len(conset))
|
||||||
for r in conset:
|
for r in conset:
|
||||||
self.sock.sendstr(wait_conn[r].host)
|
self.sock.sendstr(wait_conn[r].host)
|
||||||
self.sock.sendint(wait_conn[r].port)
|
port = wait_conn[r].port
|
||||||
|
assert port is not None
|
||||||
|
self.sock.sendint(port)
|
||||||
self.sock.sendint(r)
|
self.sock.sendint(r)
|
||||||
nerr = self.sock.recvint()
|
nerr = self.sock.recvint()
|
||||||
if nerr != 0:
|
if nerr != 0:
|
||||||
@ -138,13 +152,17 @@ class SlaveEntry(object):
|
|||||||
return rmset
|
return rmset
|
||||||
|
|
||||||
|
|
||||||
class RabitTracker(object):
|
class RabitTracker:
|
||||||
"""
|
"""
|
||||||
tracker for rabit
|
tracker for rabit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hostIP, nslave, port=9091, port_end=9999, use_logger: bool = True
|
self, hostIP: str,
|
||||||
|
n_workers: int,
|
||||||
|
port: int = 9091,
|
||||||
|
port_end: int = 9999,
|
||||||
|
use_logger: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""A Python implementation of RABIT tracker.
|
"""A Python implementation of RABIT tracker.
|
||||||
|
|
||||||
@ -168,45 +186,45 @@ class RabitTracker(object):
|
|||||||
sock.listen(256)
|
sock.listen(256)
|
||||||
self.sock = sock
|
self.sock = sock
|
||||||
self.hostIP = hostIP
|
self.hostIP = hostIP
|
||||||
self.thread = None
|
self.thread: Optional[Thread] = None
|
||||||
self.start_time = None
|
self.n_workers = n_workers
|
||||||
self.end_time = None
|
|
||||||
self.nslave = nslave
|
|
||||||
self._use_logger = use_logger
|
self._use_logger = use_logger
|
||||||
logging.info('start listen on %s:%d', hostIP, self.port)
|
logging.info('start listen on %s:%d', hostIP, self.port)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self) -> None:
|
||||||
self.sock.close()
|
self.sock.close()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_neighbor(rank, nslave):
|
def get_neighbor(rank: int, n_workers: int) -> List[int]:
|
||||||
rank = rank + 1
|
rank = rank + 1
|
||||||
ret = []
|
ret = []
|
||||||
if rank > 1:
|
if rank > 1:
|
||||||
ret.append(rank // 2 - 1)
|
ret.append(rank // 2 - 1)
|
||||||
if rank * 2 - 1 < nslave:
|
if rank * 2 - 1 < n_workers:
|
||||||
ret.append(rank * 2 - 1)
|
ret.append(rank * 2 - 1)
|
||||||
if rank * 2 < nslave:
|
if rank * 2 < n_workers:
|
||||||
ret.append(rank * 2)
|
ret.append(rank * 2)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def slave_envs(self):
|
def worker_envs(self) -> Dict[str, Union[str, int]]:
|
||||||
"""
|
"""
|
||||||
get enviroment variables for slaves
|
get enviroment variables for workers
|
||||||
can be passed in as args or envs
|
can be passed in as args or envs
|
||||||
"""
|
"""
|
||||||
return {'DMLC_TRACKER_URI': self.hostIP,
|
return {'DMLC_TRACKER_URI': self.hostIP,
|
||||||
'DMLC_TRACKER_PORT': self.port}
|
'DMLC_TRACKER_PORT': self.port}
|
||||||
|
|
||||||
def get_tree(self, nslave):
|
def get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
|
||||||
tree_map = {}
|
tree_map: _TreeMap = {}
|
||||||
parent_map = {}
|
parent_map: Dict[int, int] = {}
|
||||||
for r in range(nslave):
|
for r in range(n_workers):
|
||||||
tree_map[r] = self.get_neighbor(r, nslave)
|
tree_map[r] = self.get_neighbor(r, n_workers)
|
||||||
parent_map[r] = (r + 1) // 2 - 1
|
parent_map[r] = (r + 1) // 2 - 1
|
||||||
return tree_map, parent_map
|
return tree_map, parent_map
|
||||||
|
|
||||||
def find_share_ring(self, tree_map, parent_map, r):
|
def find_share_ring(
|
||||||
|
self, tree_map: _TreeMap, parent_map: Dict[int, int], r: int
|
||||||
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
get a ring structure that tends to share nodes with the tree
|
get a ring structure that tends to share nodes with the tree
|
||||||
return a list starting from r
|
return a list starting from r
|
||||||
@ -225,63 +243,65 @@ class RabitTracker(object):
|
|||||||
rlst += vlst
|
rlst += vlst
|
||||||
return rlst
|
return rlst
|
||||||
|
|
||||||
def get_ring(self, tree_map, parent_map):
|
def get_ring(self, tree_map: _TreeMap, parent_map: Dict[int, int]) -> _RingMap:
|
||||||
"""
|
"""
|
||||||
get a ring connection used to recover local data
|
get a ring connection used to recover local data
|
||||||
"""
|
"""
|
||||||
assert parent_map[0] == -1
|
assert parent_map[0] == -1
|
||||||
rlst = self.find_share_ring(tree_map, parent_map, 0)
|
rlst = self.find_share_ring(tree_map, parent_map, 0)
|
||||||
assert len(rlst) == len(tree_map)
|
assert len(rlst) == len(tree_map)
|
||||||
ring_map = {}
|
ring_map: _RingMap = {}
|
||||||
nslave = len(tree_map)
|
n_workers = len(tree_map)
|
||||||
for r in range(nslave):
|
for r in range(n_workers):
|
||||||
rprev = (r + nslave - 1) % nslave
|
rprev = (r + n_workers - 1) % n_workers
|
||||||
rnext = (r + 1) % nslave
|
rnext = (r + 1) % n_workers
|
||||||
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
||||||
return ring_map
|
return ring_map
|
||||||
|
|
||||||
def get_link_map(self, nslave):
|
def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]:
|
||||||
"""
|
"""
|
||||||
get the link map, this is a bit hacky, call for better algorithm
|
get the link map, this is a bit hacky, call for better algorithm
|
||||||
to place similar nodes together
|
to place similar nodes together
|
||||||
"""
|
"""
|
||||||
tree_map, parent_map = self.get_tree(nslave)
|
tree_map, parent_map = self.get_tree(n_workers)
|
||||||
ring_map = self.get_ring(tree_map, parent_map)
|
ring_map = self.get_ring(tree_map, parent_map)
|
||||||
rmap = {0: 0}
|
rmap = {0: 0}
|
||||||
k = 0
|
k = 0
|
||||||
for i in range(nslave - 1):
|
for i in range(n_workers - 1):
|
||||||
k = ring_map[k][1]
|
k = ring_map[k][1]
|
||||||
rmap[k] = i + 1
|
rmap[k] = i + 1
|
||||||
|
|
||||||
ring_map_ = {}
|
ring_map_: _RingMap = {}
|
||||||
tree_map_ = {}
|
tree_map_: _TreeMap = {}
|
||||||
parent_map_ = {}
|
parent_map_: Dict[int, int] = {}
|
||||||
for k, v in ring_map.items():
|
for k, v in ring_map.items():
|
||||||
ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
|
ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
|
||||||
for k, v in tree_map.items():
|
for k, tree_nodes in tree_map.items():
|
||||||
tree_map_[rmap[k]] = [rmap[x] for x in v]
|
tree_map_[rmap[k]] = [rmap[x] for x in tree_nodes]
|
||||||
for k, v in parent_map.items():
|
for k, parent in parent_map.items():
|
||||||
if k != 0:
|
if k != 0:
|
||||||
parent_map_[rmap[k]] = rmap[v]
|
parent_map_[rmap[k]] = rmap[parent]
|
||||||
else:
|
else:
|
||||||
parent_map_[rmap[k]] = -1
|
parent_map_[rmap[k]] = -1
|
||||||
return tree_map_, parent_map_, ring_map_
|
return tree_map_, parent_map_, ring_map_
|
||||||
|
|
||||||
def accept_slaves(self, nslave):
|
def accept_workers(self, n_workers: int) -> None:
|
||||||
# set of nodes that finishs the job
|
# set of nodes that finishs the job
|
||||||
shutdown = {}
|
shutdown: Dict[int, WorkerEntry] = {}
|
||||||
# set of nodes that is waiting for connections
|
# set of nodes that is waiting for connections
|
||||||
wait_conn = {}
|
wait_conn: Dict[int, WorkerEntry] = {}
|
||||||
# maps job id to rank
|
# maps job id to rank
|
||||||
job_map = {}
|
job_map: Dict[str, int] = {}
|
||||||
# list of workers that is pending to be assigned rank
|
# list of workers that is pending to be assigned rank
|
||||||
pending = []
|
pending: List[WorkerEntry] = []
|
||||||
# lazy initialize tree_map
|
# lazy initialize tree_map
|
||||||
tree_map = None
|
tree_map = None
|
||||||
|
|
||||||
while len(shutdown) != nslave:
|
start_time = time.time()
|
||||||
|
|
||||||
|
while len(shutdown) != n_workers:
|
||||||
fd, s_addr = self.sock.accept()
|
fd, s_addr = self.sock.accept()
|
||||||
s = SlaveEntry(fd, s_addr)
|
s = WorkerEntry(fd, s_addr)
|
||||||
if s.cmd == 'print':
|
if s.cmd == 'print':
|
||||||
msg = s.sock.recvstr()
|
msg = s.sock.recvstr()
|
||||||
# On dask we use print to avoid setting global verbosity.
|
# On dask we use print to avoid setting global verbosity.
|
||||||
@ -297,16 +317,16 @@ class RabitTracker(object):
|
|||||||
logging.debug('Received %s signal from %d', s.cmd, s.rank)
|
logging.debug('Received %s signal from %d', s.cmd, s.rank)
|
||||||
continue
|
continue
|
||||||
assert s.cmd in ("start", "recover")
|
assert s.cmd in ("start", "recover")
|
||||||
# lazily initialize the slaves
|
# lazily initialize the workers
|
||||||
if tree_map is None:
|
if tree_map is None:
|
||||||
assert s.cmd == 'start'
|
assert s.cmd == 'start'
|
||||||
if s.world_size > 0:
|
if s.world_size > 0:
|
||||||
nslave = s.world_size
|
n_workers = s.world_size
|
||||||
tree_map, parent_map, ring_map = self.get_link_map(nslave)
|
tree_map, parent_map, ring_map = self.get_link_map(n_workers)
|
||||||
# set of nodes that is pending for getting up
|
# set of nodes that is pending for getting up
|
||||||
todo_nodes = list(range(nslave))
|
todo_nodes = list(range(n_workers))
|
||||||
else:
|
else:
|
||||||
assert s.world_size in (-1, nslave)
|
assert s.world_size in (-1, n_workers)
|
||||||
if s.cmd == 'recover':
|
if s.cmd == 'recover':
|
||||||
assert s.rank >= 0
|
assert s.rank >= 0
|
||||||
|
|
||||||
@ -327,34 +347,35 @@ class RabitTracker(object):
|
|||||||
logging.debug('Received %s signal from %s; assign rank %d',
|
logging.debug('Received %s signal from %s; assign rank %d',
|
||||||
s.cmd, s.host, s.rank)
|
s.cmd, s.host, s.rank)
|
||||||
if not todo_nodes:
|
if not todo_nodes:
|
||||||
logging.info('@tracker All of %d nodes getting started', nslave)
|
logging.info('@tracker All of %d nodes getting started', n_workers)
|
||||||
self.start_time = time.time()
|
|
||||||
else:
|
else:
|
||||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||||
logging.debug('Received %s signal from %d', s.cmd, s.rank)
|
logging.debug('Received %s signal from %d', s.cmd, s.rank)
|
||||||
if s.wait_accept > 0:
|
if s.wait_accept > 0:
|
||||||
wait_conn[rank] = s
|
wait_conn[rank] = s
|
||||||
logging.info('@tracker All nodes finishes job')
|
logging.info('@tracker All nodes finishes job')
|
||||||
self.end_time = time.time()
|
end_time = time.time()
|
||||||
logging.info('@tracker %s secs between node start and job finish',
|
logging.info(
|
||||||
str(self.end_time - self.start_time))
|
'@tracker %s secs between node start and job finish',
|
||||||
|
str(end_time - start_time)
|
||||||
|
)
|
||||||
|
|
||||||
def start(self, nslave):
|
def start(self, n_workers: int) -> None:
|
||||||
def run():
|
def run() -> None:
|
||||||
self.accept_slaves(nslave)
|
self.accept_workers(n_workers)
|
||||||
|
|
||||||
self.thread = Thread(target=run, args=(), daemon=True)
|
self.thread = Thread(target=run, args=(), daemon=True)
|
||||||
self.thread.start()
|
self.thread.start()
|
||||||
|
|
||||||
def join(self):
|
def join(self) -> None:
|
||||||
while self.thread.is_alive():
|
while self.thread is not None and self.thread.is_alive():
|
||||||
self.thread.join(100)
|
self.thread.join(100)
|
||||||
|
|
||||||
def alive(self):
|
def alive(self) -> bool:
|
||||||
return self.thread.is_alive()
|
return self.thread is not None and self.thread.is_alive()
|
||||||
|
|
||||||
|
|
||||||
def get_host_ip(hostIP=None):
|
def get_host_ip(hostIP: Optional[str] = None) -> str:
|
||||||
if hostIP is None or hostIP == 'auto':
|
if hostIP is None or hostIP == 'auto':
|
||||||
hostIP = 'ip'
|
hostIP = 'ip'
|
||||||
|
|
||||||
@ -374,10 +395,12 @@ def get_host_ip(hostIP=None):
|
|||||||
# doesn't have to be reachable
|
# doesn't have to be reachable
|
||||||
s.connect(('10.255.255.255', 1))
|
s.connect(('10.255.255.255', 1))
|
||||||
hostIP = s.getsockname()[0]
|
hostIP = s.getsockname()[0]
|
||||||
|
|
||||||
|
assert hostIP is not None
|
||||||
return hostIP
|
return hostIP
|
||||||
|
|
||||||
|
|
||||||
def start_rabit_tracker(args):
|
def start_rabit_tracker(args: argparse.Namespace) -> None:
|
||||||
"""Standalone function to start rabit tracker.
|
"""Standalone function to start rabit tracker.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -386,8 +409,8 @@ def start_rabit_tracker(args):
|
|||||||
"""
|
"""
|
||||||
envs = {'DMLC_NUM_WORKER': args.num_workers,
|
envs = {'DMLC_NUM_WORKER': args.num_workers,
|
||||||
'DMLC_NUM_SERVER': args.num_servers}
|
'DMLC_NUM_SERVER': args.num_servers}
|
||||||
rabit = RabitTracker(hostIP=get_host_ip(args.host_ip), nslave=args.num_workers)
|
rabit = RabitTracker(hostIP=get_host_ip(args.host_ip), n_workers=args.num_workers)
|
||||||
envs.update(rabit.slave_envs())
|
envs.update(rabit.worker_envs())
|
||||||
rabit.start(args.num_workers)
|
rabit.start(args.num_workers)
|
||||||
sys.stdout.write('DMLC_TRACKER_ENV_START\n')
|
sys.stdout.write('DMLC_TRACKER_ENV_START\n')
|
||||||
# simply write configuration to stdout
|
# simply write configuration to stdout
|
||||||
@ -398,13 +421,15 @@ def start_rabit_tracker(args):
|
|||||||
rabit.join()
|
rabit.join()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
"""Main function if tracker is executed in standalone mode."""
|
"""Main function if tracker is executed in standalone mode."""
|
||||||
parser = argparse.ArgumentParser(description='Rabit Tracker start.')
|
parser = argparse.ArgumentParser(description='Rabit Tracker start.')
|
||||||
parser.add_argument('--num-workers', required=True, type=int,
|
parser.add_argument('--num-workers', required=True, type=int,
|
||||||
help='Number of worker proccess to be launched.')
|
help='Number of worker proccess to be launched.')
|
||||||
parser.add_argument('--num-servers', default=0, type=int,
|
parser.add_argument(
|
||||||
help='Number of server process to be launched. Only used in PS jobs.')
|
'--num-servers', default=0, type=int,
|
||||||
|
help='Number of server process to be launched. Only used in PS jobs.'
|
||||||
|
)
|
||||||
parser.add_argument('--host-ip', default=None, type=str,
|
parser.add_argument('--host-ip', default=None, type=str,
|
||||||
help=('Host IP addressed, this is only needed ' +
|
help=('Host IP addressed, this is only needed ' +
|
||||||
'if the host IP cannot be automatically guessed.'))
|
'if the host IP cannot be automatically guessed.'))
|
||||||
|
|||||||
@ -10,7 +10,7 @@ if sys.platform.startswith("win"):
|
|||||||
|
|
||||||
|
|
||||||
def test_rabit_tracker():
|
def test_rabit_tracker():
|
||||||
tracker = RabitTracker(hostIP='127.0.0.1', nslave=1)
|
tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1)
|
||||||
tracker.start(1)
|
tracker.start(1)
|
||||||
rabit_env = [
|
rabit_env = [
|
||||||
str.encode('DMLC_TRACKER_URI=127.0.0.1'),
|
str.encode('DMLC_TRACKER_URI=127.0.0.1'),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user