From 35c3b371ea9d283e7b1a83ffa7dd78324188227e Mon Sep 17 00:00:00 2001 From: "yonglong.dyl" Date: Wed, 21 Oct 2015 10:24:07 +0800 Subject: [PATCH] add retry mechanism to ConnectTracker and modify Listen backlog to 128 in rabit_tracker.py --- src/allreduce_base.cc | 25 ++++++++++++++++++++++--- src/allreduce_base.h | 2 ++ tracker/rabit_tracker.py | 32 ++++++++++++++++---------------- 3 files changed, 40 insertions(+), 19 deletions(-) diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 12d88550e..917d1dffb 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -24,6 +24,7 @@ AllreduceBase::AllreduceBase(void) { nport_trial = 1000; rank = 0; world_size = -1; + connect_retry = 5; hadoop_mode = 0; version_number = 0; // 32 K items @@ -46,6 +47,7 @@ AllreduceBase::AllreduceBase(void) { env_vars.push_back("DMLC_NUM_ATTEMPT"); env_vars.push_back("DMLC_TRACKER_URI"); env_vars.push_back("DMLC_TRACKER_PORT"); + env_vars.push_back("DMLC_WORKER_CONNECT_RETRY"); } // initialization function @@ -175,6 +177,9 @@ void AllreduceBase::SetParam(const char *name, const char *val) { if (!strcmp(name, "rabit_reduce_buffer")) { reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3; } + if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) { + connect_retry = atoi(val); + } } /*! 
* \brief initialize connection to the tracker @@ -185,9 +190,23 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const { // get information from tracker utils::TCPSocket tracker; tracker.Create(); - if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) { - utils::Socket::Error("Connect"); - } + + int retry = 0; + do { + fprintf(stderr, "connect to ip: [%s]\n", tracker_uri.c_str()); + if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) { + if (++retry >= connect_retry) { + fprintf(stderr, "connect to (failed): [%s]\n", tracker_uri.c_str()); + utils::Socket::Error("Connect"); + } else { + fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str()); + sleep(1); + continue; + } + } + break; + } while (1); + using utils::Assert; Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1"); diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 1d820b0b4..63acd75d5 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -519,6 +519,8 @@ class AllreduceBase : public IEngine { int rank; // world size int world_size; + // connect retry time + int connect_retry; }; } // namespace engine } // namespace rabit diff --git a/tracker/rabit_tracker.py b/tracker/rabit_tracker.py index c8dd896f1..d8e6ae84d 100644 --- a/tracker/rabit_tracker.py +++ b/tracker/rabit_tracker.py @@ -1,6 +1,6 @@ """ Tracker script for rabit -Implements the tracker control protocol +Implements the tracker control protocol - start rabit jobs - help nodes to establish links with each other @@ -19,13 +19,13 @@ from threading import Thread """ Extension of socket to handle recv and send of special data """ -class ExSocket: +class ExSocket: def __init__(self, sock): self.sock = sock def recvall(self, nbytes): res = [] sock = self.sock - nread = 0 + nread = 0 while nread < nbytes: chunk = self.sock.recv(min(nbytes - nread, 1024)) nread += len(chunk) @@ -106,7 +106,7 @@ class SlaveEntry: for 
r in conset: self.sock.sendstr(wait_conn[r].host) self.sock.sendint(wait_conn[r].port) - self.sock.sendint(r) + self.sock.sendint(r) nerr = self.sock.recvint() if nerr != 0: continue @@ -121,7 +121,7 @@ class SlaveEntry: wait_conn.pop(r, None) self.wait_accept = len(badset) - len(conset) return rmset - + class Tracker: def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -132,7 +132,7 @@ class Tracker: break except socket.error: continue - sock.listen(16) + sock.listen(128) self.sock = sock self.verbose = verbose if hostIP == 'auto': @@ -145,7 +145,7 @@ class Tracker: """ get enviroment variables for slaves can be passed in as args or envs - """ + """ if self.hostIP == 'dns': host = socket.gethostname() elif self.hostIP == 'ip': @@ -153,14 +153,14 @@ class Tracker: else: host = self.hostIP return {'rabit_tracker_uri': host, - 'rabit_tracker_port': self.port} + 'rabit_tracker_port': self.port} def get_neighbor(self, rank, nslave): rank = rank + 1 ret = [] if rank > 1: ret.append(rank / 2 - 1) if rank * 2 - 1 < nslave: - ret.append(rank * 2 - 1) + ret.append(rank * 2 - 1) if rank * 2 < nslave: ret.append(rank * 2) return ret @@ -198,10 +198,10 @@ class Tracker: rlst = self.find_share_ring(tree_map, parent_map, 0) assert len(rlst) == len(tree_map) ring_map = {} - nslave = len(tree_map) + nslave = len(tree_map) for r in range(nslave): rprev = (r + nslave - 1) % nslave - rnext = (r + 1) % nslave + rnext = (r + 1) % nslave ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) return ring_map @@ -231,7 +231,7 @@ class Tracker: else: parent_map_[rmap[k]] = -1 return tree_map_, parent_map_, ring_map_ - + def handle_print(self,slave, msg): sys.stdout.write(msg) @@ -253,14 +253,14 @@ class Tracker: pending = [] # lazy initialize tree_map tree_map = None - + while len(shutdown) != nslave: fd, s_addr = self.sock.accept() s = SlaveEntry(fd, s_addr) if s.cmd == 'print': msg = s.sock.recvstr() 
self.handle_print(s, msg) - continue + continue if s.cmd == 'shutdown': assert s.rank >= 0 and s.rank not in shutdown assert s.rank not in wait_conn @@ -280,12 +280,12 @@ class Tracker: assert s.world_size == -1 or s.world_size == nslave if s.cmd == 'recover': assert s.rank >= 0 - + rank = s.decide_rank(job_map) # batch assignment of ranks if rank == -1: assert len(todo_nodes) != 0 - pending.append(s) + pending.append(s) if len(pending) == len(todo_nodes): pending.sort(key = lambda x : x.host) for s in pending: