28ca7be add linear readme ca4b20f add linear readme 1133628 add linear readme 6a11676 update docs a607047 Update build.sh 2c1cfd8 complete yarn 4f28e32 change formater 2fbda81 fix stdin input 3258bcf checkin yarn master 67ebf81 allow setup from env variables 9b6bf57 fix hdfs 395d5c2 add make system 88ce767 refactor io, initial hdfs file access need test 19be870 chgs a1bd3c6 Merge branch 'master' of ssh://github.com/tqchen/rabit 1a573f9 introduce input split 29476f1 fix timer issue git-subtree-dir: subtree/rabit git-subtree-split: 28ca7becbdf6503e6b1398588a969efb164c9701
277 lines
9.3 KiB
Python
277 lines
9.3 KiB
Python
"""
|
|
Tracker script for rabit
|
|
Implements the tracker control protocol
|
|
- start rabit jobs
|
|
- help nodes to establish links with each other
|
|
|
|
Tianqi Chen
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import socket
|
|
import struct
|
|
import subprocess
|
|
import random
|
|
from threading import Thread
|
|
|
|
"""
|
|
Extension of socket to handle recv and send of special data
|
|
"""
|
|
class ExSocket:
    """Wrapper around a connected socket that speaks the tracker wire format.

    Integers are exchanged with struct format '@i' (native byte order and
    size); strings are exchanged as an integer byte length followed by the
    raw bytes.
    """

    def __init__(self, sock):
        """Wrap *sock*, an already connected socket object."""
        self.sock = sock

    def recvall(self, nbytes):
        """Read exactly *nbytes* bytes from the socket and return them.

        Raises socket.error if the peer closes the connection before all
        requested bytes have arrived.
        """
        chunks = []
        nread = 0
        while nread < nbytes:
            chunk = self.sock.recv(min(nbytes - nread, 1024))
            if not chunk:
                # recv() returning nothing means the peer closed the
                # connection; without this check the loop never terminates
                raise socket.error(
                    'connection closed after %d of %d bytes' % (nread, nbytes))
            nread += len(chunk)
            chunks.append(chunk)
        # b'' is plain str on Python 2 and bytes on Python 3
        return b''.join(chunks)

    def recvint(self):
        """Receive one integer sent with struct format '@i'."""
        return struct.unpack('@i', self.recvall(4))[0]

    def sendint(self, n):
        """Send integer *n* packed with struct format '@i'."""
        self.sock.sendall(struct.pack('@i', n))

    def sendstr(self, s):
        """Send string *s* as its byte length followed by its bytes."""
        # encode text first so the announced length is the byte length
        data = s if isinstance(s, bytes) else s.encode()
        self.sendint(len(data))
        self.sock.sendall(data)

    def recvstr(self):
        """Receive a length-prefixed string and return it as a native str."""
        slen = self.recvint()
        data = self.recvall(slen)
        # on Python 2 the received bytes already are a str
        return data if isinstance(data, str) else data.decode()
|
|
|
|
# magic number exchanged during the connection handshake so that tracker and worker can verify each other
|
|
kMagic = 0xff99
|
|
|
|
class SlaveEntry:
    """Tracker-side record of one worker connection.

    Construction performs the handshake on the accepted socket: the magic
    number is received and echoed back, then the worker's rank, world size,
    job id and command are read in that order.
    """

    def __init__(self, sock, s_addr):
        """Handshake with the worker behind *sock*.

        sock   -- raw accepted socket
        s_addr -- peer address as returned by accept(); s_addr[0] is kept
                  as self.host
        """
        slave = ExSocket(sock)
        self.sock = slave
        self.host = s_addr[0]
        magic = slave.recvint()
        assert magic == kMagic, 'invalid magic number=%d from %s' % (magic, s_addr[0])
        slave.sendint(kMagic)
        self.rank = slave.recvint()
        self.world_size = slave.recvint()
        self.jobid = slave.recvstr()
        self.cmd = slave.recvstr()

    def decide_rank(self, job_map):
        """Return the rank this worker should get, or -1 if undecided.

        A non-negative rank reported by the worker wins; otherwise the rank
        previously recorded for its job id in *job_map* is used.
        """
        if self.rank >= 0:
            return self.rank
        if self.jobid != 'NULL' and self.jobid in job_map:
            return job_map[self.jobid]
        return -1

    def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
        """Send *rank* and its topology to the worker and broker its links.

        rank       -- rank assigned to this worker
        wait_conn  -- dict rank -> SlaveEntry of workers still waiting for
                      incoming connections; updated in place
        tree_map   -- dict rank -> list of tree neighbour ranks
        parent_map -- dict rank -> parent rank
        ring_map   -- dict rank -> (previous rank, next rank)

        Returns the list of ranks removed from *wait_conn* because all the
        connections they were waiting for have now been made. Sets
        self.port and self.wait_accept as a side effect.
        """
        self.rank = rank
        nnset = set(tree_map[rank])
        rprev, rnext = ring_map[rank]
        self.sock.sendint(rank)
        # send parent rank
        self.sock.sendint(parent_map[rank])
        # send world size
        self.sock.sendint(len(tree_map))
        # send the number of tree neighbours, then each neighbour rank
        self.sock.sendint(len(nnset))
        for r in nnset:
            self.sock.sendint(r)
        # send prev link (-1 when there is none or it is the node itself)
        if rprev != -1 and rprev != rank:
            nnset.add(rprev)
            self.sock.sendint(rprev)
        else:
            self.sock.sendint(-1)
        # send next link (-1 when there is none or it is the node itself)
        if rnext != -1 and rnext != rank:
            nnset.add(rnext)
            self.sock.sendint(rnext)
        else:
            self.sock.sendint(-1)
        while True:
            # the worker reports which of its links are already established
            ngood = self.sock.recvint()
            goodset = set()
            # range (not xrange) keeps this loop working on Python 2 and 3
            for _ in range(ngood):
                goodset.add(self.sock.recvint())
            assert goodset.issubset(nnset)
            badset = nnset - goodset
            # missing links whose peer is already waiting for connections
            conset = [r for r in badset if r in wait_conn]
            # number of peers to connect to, then number of peers to wait for
            self.sock.sendint(len(conset))
            self.sock.sendint(len(badset) - len(conset))
            for r in conset:
                self.sock.sendstr(wait_conn[r].host)
                self.sock.sendint(wait_conn[r].port)
                self.sock.sendint(r)
            nerr = self.sock.recvint()
            if nerr != 0:
                # the worker reported errors; run the exchange again
                continue
            self.port = self.sock.recvint()
            rmset = []
            # all connections were successfully set up
            for r in conset:
                wait_conn[r].wait_accept -= 1
                if wait_conn[r].wait_accept == 0:
                    rmset.append(r)
            for r in rmset:
                wait_conn.pop(r, None)
            self.wait_accept = len(badset) - len(conset)
            return rmset
|
|
|
|
class Tracker:
    """Tracker server that workers contact to obtain a rank and peer links.

    The tracker listens on a TCP port, builds a binary-tree topology and a
    ring topology over the worker ranks, and serves worker commands
    ('start', 'recover', 'print', 'shutdown') until every worker has sent
    'shutdown'.
    """

    def __init__(self, port=9091, port_end=9999, verbose=True, hostIP='auto'):
        """Bind to the first free port in [port, port_end) and listen.

        port, port_end -- half-open range of candidate ports
        verbose        -- whether level-1 log messages are printed
        hostIP         -- 'auto' (treated as 'dns'), 'dns', 'ip', or an
                          explicit host string; see slave_envs

        Raises socket.error when no port in the range can be bound.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        bound_port = None
        for candidate in range(port, port_end):
            try:
                sock.bind(('', candidate))
            except socket.error:
                continue
            bound_port = candidate
            break
        if bound_port is None:
            # fail loudly instead of listening on an unbound socket and
            # leaving self.port undefined
            sock.close()
            raise socket.error(
                'unable to bind tracker to any port in [%d, %d)' % (port, port_end))
        self.port = bound_port
        sock.listen(16)
        self.sock = sock
        self.verbose = verbose
        if hostIP == 'auto':
            hostIP = 'dns'
        self.hostIP = hostIP
        self.log_print('start listen on %s:%d' % (socket.gethostname(), self.port), 1)

    def __del__(self):
        """Close the listening socket, if construction got far enough to set it."""
        sock = getattr(self, 'sock', None)
        if sock is not None:
            sock.close()

    def slave_envs(self):
        """
        get environment variables for slaves
        can be passed in as args or envs

        Returns a dict with 'rabit_tracker_uri' (host name, resolved IP, or
        the explicit hostIP string) and 'rabit_tracker_port'.
        """
        if self.hostIP == 'dns':
            host = socket.gethostname()
        elif self.hostIP == 'ip':
            host = socket.gethostbyname(socket.getfqdn())
        else:
            host = self.hostIP
        return {'rabit_tracker_uri': host,
                'rabit_tracker_port': self.port}

    def get_neighbor(self, rank, nslave):
        """Return the tree neighbours (parent first, then children) of *rank*.

        Ranks are laid out as an implicit binary tree: with 1-based index
        i = rank + 1, the parent is i // 2 and the children are 2i and
        2i + 1, all converted back to 0-based ranks.
        """
        rank = rank + 1
        ret = []
        if rank > 1:
            # floor division keeps ranks integral on Python 3 as well
            ret.append(rank // 2 - 1)
        if rank * 2 - 1 < nslave:
            ret.append(rank * 2 - 1)
        if rank * 2 < nslave:
            ret.append(rank * 2)
        return ret

    def get_tree(self, nslave):
        """Build the tree topology for *nslave* workers.

        Returns (tree_map, parent_map): tree_map maps each rank to its list
        of neighbours, parent_map maps each rank to its parent (-1 for the
        root, rank 0).
        """
        tree_map = {}
        parent_map = {}
        for r in range(nslave):
            tree_map[r] = self.get_neighbor(r, nslave)
            parent_map[r] = (r + 1) // 2 - 1
        return tree_map, parent_map

    def find_share_ring(self, tree_map, parent_map, r):
        """
        get a ring structure that tends to share nodes with the tree
        return a list starting from r
        """
        # children of r: its neighbours without its parent
        cset = set(tree_map[r]) - set([parent_map[r]])
        if len(cset) == 0:
            return [r]
        rlst = [r]
        cnt = 0
        for v in cset:
            vlst = self.find_share_ring(tree_map, parent_map, v)
            cnt += 1
            if cnt == len(cset):
                # the last child's sub-list is appended in reverse order
                vlst.reverse()
            rlst += vlst
        return rlst

    def get_ring(self, tree_map, parent_map):
        """
        get a ring connection used to recover local data

        Returns a dict mapping each rank to (previous rank, next rank).
        """
        assert parent_map[0] == -1
        rlst = self.find_share_ring(tree_map, parent_map, 0)
        assert len(rlst) == len(tree_map)
        ring_map = {}
        nslave = len(tree_map)
        for r in range(nslave):
            rprev = (r + nslave - 1) % nslave
            rnext = (r + 1) % nslave
            ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
        return ring_map

    def handle_print(self, slave, msg):
        """Write a worker's 'print' message to stdout."""
        sys.stdout.write(msg)

    def log_print(self, msg, level):
        """Write *msg* to stderr; level-1 messages only when verbose is set."""
        if level != 1 or self.verbose:
            sys.stderr.write(msg + '\n')

    def accept_slaves(self, nslave):
        """Serve worker connections until *nslave* workers have shut down.

        nslave is replaced by the world size reported by the first 'start'
        request when that value is positive.
        """
        # set of nodes that finishes the job
        shutdown = {}
        # set of nodes that is waiting for connections
        wait_conn = {}
        # maps job id to rank
        job_map = {}
        # lazy initialize tree_map
        tree_map = None
        parent_map = None
        ring_map = None
        # ranks that have not been handed out yet
        todo_nodes = None

        while len(shutdown) != nslave:
            fd, s_addr = self.sock.accept()
            s = SlaveEntry(fd, s_addr)
            if s.cmd == 'print':
                msg = s.sock.recvstr()
                self.handle_print(s, msg)
                continue
            if s.cmd == 'shutdown':
                assert s.rank >= 0 and s.rank not in shutdown
                assert s.rank not in wait_conn
                shutdown[s.rank] = s
                self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
                continue
            assert s.cmd == 'start' or s.cmd == 'recover'
            # lazily initialize the slaves
            if tree_map is None:
                assert s.cmd == 'start'
                if s.world_size > 0:
                    nslave = s.world_size
                tree_map, parent_map = self.get_tree(nslave)
                ring_map = self.get_ring(tree_map, parent_map)
                # set of nodes that is pending for getting up; a real list
                # is required because shuffle works in place
                todo_nodes = list(range(nslave))
                random.shuffle(todo_nodes)
            else:
                assert s.world_size == -1 or s.world_size == nslave
            if s.cmd == 'recover':
                assert s.rank >= 0
            rank = s.decide_rank(job_map)
            if rank == -1:
                assert len(todo_nodes) != 0
                rank = todo_nodes.pop(0)
                if s.jobid != 'NULL':
                    job_map[s.jobid] = rank
                if len(todo_nodes) == 0:
                    self.log_print('@tracker All of %d nodes getting started' % nslave, 2)
            s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
            if s.cmd != 'start':
                self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
            else:
                self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
            if s.wait_accept > 0:
                wait_conn[rank] = s
        self.log_print('@tracker All nodes finishes job', 2)
|
|
|
|
def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):
    """Start a tracker, launch the job through fun_submit and serve it.

    fun_submit is called on a daemon thread as
    fun_submit(nslave, args, envs), where envs is the dict returned by the
    tracker's slave_envs(). The calling thread runs accept_slaves(nslave)
    and then joins the launcher thread.
    """
    tracker = Tracker(verbose = verbose, hostIP = hostIP)
    launch_args = (nslave, args, tracker.slave_envs())
    launcher = Thread(target = fun_submit, args = launch_args)
    launcher.daemon = True
    launcher.start()
    tracker.accept_slaves(nslave)
    launcher.join()
|