Merge branch 'master' of ssh://github.com/tqchen/rabit

Conflicts:
	tracker/rabit_tracker.py
This commit is contained in:
tqchen
2015-03-11 13:35:35 -07:00
13 changed files with 440 additions and 64 deletions

View File

@@ -189,6 +189,7 @@ class Tracker:
vlst.reverse()
rlst += vlst
return rlst
def get_ring(self, tree_map, parent_map):
"""
get a ring connection used to recover local data
@@ -203,14 +204,44 @@ class Tracker:
rnext = (r + 1) % nslave
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
return ring_map
def get_link_map(self, nslave):
    """
    build the tree, parent and ring maps for nslave nodes, then
    renumber every node by its position along the ring

    Node 0 keeps rank 0; the node reached after i hops in the "next"
    direction of the ring receives rank i. All three maps are rewritten
    with the new numbering, so consecutive ranks are ring neighbours.

    returns (tree_map, parent_map, ring_map) keyed by the new ranks;
    the parent of node 0 is reported as -1
    """
    tree_map, parent_map = self.get_tree(nslave)
    ring_map = self.get_ring(tree_map, parent_map)

    # walk the ring once from node 0, numbering nodes in visiting order
    relabel = {0: 0}
    cursor = 0
    for new_rank in range(1, nslave):
        cursor = ring_map[cursor][1]
        relabel[cursor] = new_rank

    # translate both keys and values of each map into the new numbering
    new_ring = {
        relabel[node]: (relabel[prev], relabel[nxt])
        for node, (prev, nxt) in ring_map.items()
    }
    new_tree = {
        relabel[node]: [relabel[peer] for peer in peers]
        for node, peers in tree_map.items()
    }
    # node 0 has no parent to translate, so it is pinned to -1
    new_parent = {
        relabel[node]: (relabel[par] if node != 0 else -1)
        for node, par in parent_map.items()
    }
    return new_tree, new_parent, new_ring
def handle_print(self, slave, msg):
    """
    relay a message to the tracker's standard output

    slave: accepted for interface compatibility; not used here
    msg: text to emit, written verbatim (no newline is appended)

    The stream is flushed after every message so the text is not held
    back in the stdout buffer when stdout is a pipe or a file.
    """
    sys.stdout.write(msg)
    sys.stdout.flush()
def log_print(self, msg, level):
    """
    write msg plus a trailing newline to standard error

    Messages with level == 1 are emitted only when self.verbose is
    truthy; messages at any other level are always emitted.
    """
    # level-1 output is optional chatter: drop it unless verbose is on
    if level == 1 and not self.verbose:
        return
    sys.stderr.write(msg + '\n')
def accept_slaves(self, nslave):
# set of nodes that finishs the job
shutdown = {}
@@ -242,31 +273,37 @@ class Tracker:
assert s.cmd == 'start'
if s.world_size > 0:
nslave = s.world_size
tree_map, parent_map = self.get_tree(nslave)
ring_map = self.get_ring(tree_map, parent_map)
tree_map, parent_map, ring_map = self.get_link_map(nslave)
# set of nodes that is pending for getting up
todo_nodes = range(nslave)
random.shuffle(todo_nodes)
else:
assert s.world_size == -1 or s.world_size == nslave
if s.cmd == 'recover':
assert s.rank >= 0
rank = s.decide_rank(job_map)
# batch assignment of ranks
if rank == -1:
assert len(todo_nodes) != 0
rank = todo_nodes.pop(0)
if s.jobid != 'NULL':
job_map[s.jobid] = rank
pending.append(s)
if len(pending) == len(todo_nodes):
pending.sort(key = lambda x : x.host)
for s in pending:
rank = todo_nodes.pop(0)
if s.jobid != 'NULL':
job_map[s.jobid] = rank
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
if s.wait_accept > 0:
wait_conn[rank] = s
self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
if len(todo_nodes) == 0:
self.log_print('@tracker All of %d nodes getting started' % nslave, 2)
self.start_time = time.time()
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
if s.cmd != 'start':
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
else:
self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
if s.wait_accept > 0:
wait_conn[rank] = s
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
if s.wait_accept > 0:
wait_conn[rank] = s
self.log_print('@tracker All nodes finishes job', 2)
self.end_time = time.time()
self.log_print('@tracker %s secs between node start and job finish' % str(self.end_time - self.start_time), 2)