diff --git a/src/rabit_tracker.py b/src/rabit_tracker.py index fe01a87da..c7068e31d 100644 --- a/src/rabit_tracker.py +++ b/src/rabit_tracker.py @@ -88,7 +88,6 @@ class SlaveEntry: self.sock.sendint(rnext) else: self.sock.sendint(-1) - while True: ngood = self.sock.recvint() goodset = set([]) @@ -156,13 +155,37 @@ class Tracker: tree_map[r] = self.get_neighbor(r, nslave) parent_map[r] = (r + 1) / 2 - 1 return tree_map, parent_map + def find_share_ring(self, tree_map, parent_map, r): + """ + get a ring structure that tends to share nodes with the tree + return a list starting from r + """ + nset = set(tree_map[r]) + cset = nset - set([parent_map[r]]) + if len(cset) == 0: + return [r] + rlst = [r] + cnt = 0 + for v in cset: + vlst = self.find_share_ring(tree_map, parent_map, v) + cnt += 1 + if cnt == len(cset): + vlst.reverse() + rlst += vlst + return rlst def get_ring(self, tree_map, parent_map): + """ + get a ring connection used to recover local data + """ + assert parent_map[0] == -1 + rlst = self.find_share_ring(tree_map, parent_map, 0) + assert len(rlst) == len(tree_map) ring_map = {} - nslave = len(tree_map) + nslave = len(tree_map) for r in range(nslave): rprev = (r + nslave - 1) % nslave rnext = (r + 1) % nslave - ring_map[r] = (rprev, rnext) + ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) return ring_map def accept_slaves(self, nslave): tree_map, parent_map = self.get_tree(nslave)