Merge commit 'a16289b2047a7c2ec36667f6031dbb648e4d2caa'
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Tracker script for rabit
|
||||
Implements the tracker control protocol
|
||||
Implements the tracker control protocol
|
||||
- start rabit jobs
|
||||
- help nodes to establish links with each other
|
||||
|
||||
@@ -19,13 +19,13 @@ from threading import Thread
|
||||
"""
|
||||
Extension of socket to handle recv and send of special data
|
||||
"""
|
||||
class ExSocket:
|
||||
class ExSocket:
|
||||
def __init__(self, sock):
|
||||
self.sock = sock
|
||||
def recvall(self, nbytes):
|
||||
res = []
|
||||
sock = self.sock
|
||||
nread = 0
|
||||
nread = 0
|
||||
while nread < nbytes:
|
||||
chunk = self.sock.recv(min(nbytes - nread, 1024))
|
||||
nread += len(chunk)
|
||||
@@ -106,7 +106,7 @@ class SlaveEntry:
|
||||
for r in conset:
|
||||
self.sock.sendstr(wait_conn[r].host)
|
||||
self.sock.sendint(wait_conn[r].port)
|
||||
self.sock.sendint(r)
|
||||
self.sock.sendint(r)
|
||||
nerr = self.sock.recvint()
|
||||
if nerr != 0:
|
||||
continue
|
||||
@@ -121,7 +121,7 @@ class SlaveEntry:
|
||||
wait_conn.pop(r, None)
|
||||
self.wait_accept = len(badset) - len(conset)
|
||||
return rmset
|
||||
|
||||
|
||||
class Tracker:
|
||||
def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
@@ -132,7 +132,7 @@ class Tracker:
|
||||
break
|
||||
except socket.error:
|
||||
continue
|
||||
sock.listen(16)
|
||||
sock.listen(128)
|
||||
self.sock = sock
|
||||
self.verbose = verbose
|
||||
if hostIP == 'auto':
|
||||
@@ -145,7 +145,7 @@ class Tracker:
|
||||
"""
|
||||
get environment variables for slaves
|
||||
can be passed in as args or envs
|
||||
"""
|
||||
"""
|
||||
if self.hostIP == 'dns':
|
||||
host = socket.gethostname()
|
||||
elif self.hostIP == 'ip':
|
||||
@@ -153,14 +153,14 @@ class Tracker:
|
||||
else:
|
||||
host = self.hostIP
|
||||
return {'rabit_tracker_uri': host,
|
||||
'rabit_tracker_port': self.port}
|
||||
'rabit_tracker_port': self.port}
|
||||
def get_neighbor(self, rank, nslave):
|
||||
rank = rank + 1
|
||||
ret = []
|
||||
if rank > 1:
|
||||
ret.append(rank / 2 - 1)
|
||||
if rank * 2 - 1 < nslave:
|
||||
ret.append(rank * 2 - 1)
|
||||
ret.append(rank * 2 - 1)
|
||||
if rank * 2 < nslave:
|
||||
ret.append(rank * 2)
|
||||
return ret
|
||||
@@ -198,10 +198,10 @@ class Tracker:
|
||||
rlst = self.find_share_ring(tree_map, parent_map, 0)
|
||||
assert len(rlst) == len(tree_map)
|
||||
ring_map = {}
|
||||
nslave = len(tree_map)
|
||||
nslave = len(tree_map)
|
||||
for r in range(nslave):
|
||||
rprev = (r + nslave - 1) % nslave
|
||||
rnext = (r + 1) % nslave
|
||||
rnext = (r + 1) % nslave
|
||||
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
||||
return ring_map
|
||||
|
||||
@@ -231,7 +231,7 @@ class Tracker:
|
||||
else:
|
||||
parent_map_[rmap[k]] = -1
|
||||
return tree_map_, parent_map_, ring_map_
|
||||
|
||||
|
||||
def handle_print(self,slave, msg):
|
||||
sys.stdout.write(msg)
|
||||
|
||||
@@ -253,14 +253,14 @@ class Tracker:
|
||||
pending = []
|
||||
# lazy initialize tree_map
|
||||
tree_map = None
|
||||
|
||||
|
||||
while len(shutdown) != nslave:
|
||||
fd, s_addr = self.sock.accept()
|
||||
s = SlaveEntry(fd, s_addr)
|
||||
if s.cmd == 'print':
|
||||
msg = s.sock.recvstr()
|
||||
self.handle_print(s, msg)
|
||||
continue
|
||||
continue
|
||||
if s.cmd == 'shutdown':
|
||||
assert s.rank >= 0 and s.rank not in shutdown
|
||||
assert s.rank not in wait_conn
|
||||
@@ -280,12 +280,12 @@ class Tracker:
|
||||
assert s.world_size == -1 or s.world_size == nslave
|
||||
if s.cmd == 'recover':
|
||||
assert s.rank >= 0
|
||||
|
||||
|
||||
rank = s.decide_rank(job_map)
|
||||
# batch assignment of ranks
|
||||
if rank == -1:
|
||||
assert len(todo_nodes) != 0
|
||||
pending.append(s)
|
||||
pending.append(s)
|
||||
if len(pending) == len(todo_nodes):
|
||||
pending.sort(key = lambda x : x.host)
|
||||
for s in pending:
|
||||
|
||||
Reference in New Issue
Block a user