Merge pull request #25 from daiyl0320/master

add retry mechanism to ConnectTracker and modify Listen backlog to 128 in rabit_traker.py
This commit is contained in:
Tianqi Chen 2015-10-20 19:34:01 -07:00
commit e81a11dd7e
3 changed files with 40 additions and 19 deletions

View File

@ -24,6 +24,7 @@ AllreduceBase::AllreduceBase(void) {
nport_trial = 1000; nport_trial = 1000;
rank = 0; rank = 0;
world_size = -1; world_size = -1;
connect_retry = 5;
hadoop_mode = 0; hadoop_mode = 0;
version_number = 0; version_number = 0;
// 32 K items // 32 K items
@ -46,6 +47,7 @@ AllreduceBase::AllreduceBase(void) {
env_vars.push_back("DMLC_NUM_ATTEMPT"); env_vars.push_back("DMLC_NUM_ATTEMPT");
env_vars.push_back("DMLC_TRACKER_URI"); env_vars.push_back("DMLC_TRACKER_URI");
env_vars.push_back("DMLC_TRACKER_PORT"); env_vars.push_back("DMLC_TRACKER_PORT");
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
} }
// initialization function // initialization function
@ -175,6 +177,9 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "rabit_reduce_buffer")) { if (!strcmp(name, "rabit_reduce_buffer")) {
reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3; reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
} }
if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
connect_retry = atoi(val);
}
} }
/*! /*!
* \brief initialize connection to the tracker * \brief initialize connection to the tracker
@ -185,9 +190,23 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
// get information from tracker // get information from tracker
utils::TCPSocket tracker; utils::TCPSocket tracker;
tracker.Create(); tracker.Create();
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
utils::Socket::Error("Connect"); int retry = 0;
} do {
fprintf(stderr, "connect to ip: [%s]\n", tracker_uri.c_str());
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
if (++retry >= connect_retry) {
fprintf(stderr, "connect to (failed): [%s]\n", tracker_uri.c_str());
utils::Socket::Error("Connect");
} else {
fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
sleep(1);
continue;
}
}
break;
} while (1);
using utils::Assert; using utils::Assert;
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 1"); "ReConnectLink failure 1");

View File

@ -519,6 +519,8 @@ class AllreduceBase : public IEngine {
int rank; int rank;
// world size // world size
int world_size; int world_size;
// connect retry time
int connect_retry;
}; };
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit

View File

@ -1,6 +1,6 @@
""" """
Tracker script for rabit Tracker script for rabit
Implements the tracker control protocol Implements the tracker control protocol
- start rabit jobs - start rabit jobs
- help nodes to establish links with each other - help nodes to establish links with each other
@ -19,13 +19,13 @@ from threading import Thread
""" """
Extension of socket to handle recv and send of special data Extension of socket to handle recv and send of special data
""" """
class ExSocket: class ExSocket:
def __init__(self, sock): def __init__(self, sock):
self.sock = sock self.sock = sock
def recvall(self, nbytes): def recvall(self, nbytes):
res = [] res = []
sock = self.sock sock = self.sock
nread = 0 nread = 0
while nread < nbytes: while nread < nbytes:
chunk = self.sock.recv(min(nbytes - nread, 1024)) chunk = self.sock.recv(min(nbytes - nread, 1024))
nread += len(chunk) nread += len(chunk)
@ -106,7 +106,7 @@ class SlaveEntry:
for r in conset: for r in conset:
self.sock.sendstr(wait_conn[r].host) self.sock.sendstr(wait_conn[r].host)
self.sock.sendint(wait_conn[r].port) self.sock.sendint(wait_conn[r].port)
self.sock.sendint(r) self.sock.sendint(r)
nerr = self.sock.recvint() nerr = self.sock.recvint()
if nerr != 0: if nerr != 0:
continue continue
@ -121,7 +121,7 @@ class SlaveEntry:
wait_conn.pop(r, None) wait_conn.pop(r, None)
self.wait_accept = len(badset) - len(conset) self.wait_accept = len(badset) - len(conset)
return rmset return rmset
class Tracker: class Tracker:
def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'): def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -132,7 +132,7 @@ class Tracker:
break break
except socket.error: except socket.error:
continue continue
sock.listen(16) sock.listen(128)
self.sock = sock self.sock = sock
self.verbose = verbose self.verbose = verbose
if hostIP == 'auto': if hostIP == 'auto':
@ -145,7 +145,7 @@ class Tracker:
""" """
get enviroment variables for slaves get enviroment variables for slaves
can be passed in as args or envs can be passed in as args or envs
""" """
if self.hostIP == 'dns': if self.hostIP == 'dns':
host = socket.gethostname() host = socket.gethostname()
elif self.hostIP == 'ip': elif self.hostIP == 'ip':
@ -153,14 +153,14 @@ class Tracker:
else: else:
host = self.hostIP host = self.hostIP
return {'rabit_tracker_uri': host, return {'rabit_tracker_uri': host,
'rabit_tracker_port': self.port} 'rabit_tracker_port': self.port}
def get_neighbor(self, rank, nslave): def get_neighbor(self, rank, nslave):
rank = rank + 1 rank = rank + 1
ret = [] ret = []
if rank > 1: if rank > 1:
ret.append(rank / 2 - 1) ret.append(rank / 2 - 1)
if rank * 2 - 1 < nslave: if rank * 2 - 1 < nslave:
ret.append(rank * 2 - 1) ret.append(rank * 2 - 1)
if rank * 2 < nslave: if rank * 2 < nslave:
ret.append(rank * 2) ret.append(rank * 2)
return ret return ret
@ -198,10 +198,10 @@ class Tracker:
rlst = self.find_share_ring(tree_map, parent_map, 0) rlst = self.find_share_ring(tree_map, parent_map, 0)
assert len(rlst) == len(tree_map) assert len(rlst) == len(tree_map)
ring_map = {} ring_map = {}
nslave = len(tree_map) nslave = len(tree_map)
for r in range(nslave): for r in range(nslave):
rprev = (r + nslave - 1) % nslave rprev = (r + nslave - 1) % nslave
rnext = (r + 1) % nslave rnext = (r + 1) % nslave
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
return ring_map return ring_map
@ -231,7 +231,7 @@ class Tracker:
else: else:
parent_map_[rmap[k]] = -1 parent_map_[rmap[k]] = -1
return tree_map_, parent_map_, ring_map_ return tree_map_, parent_map_, ring_map_
def handle_print(self,slave, msg): def handle_print(self,slave, msg):
sys.stdout.write(msg) sys.stdout.write(msg)
@ -253,14 +253,14 @@ class Tracker:
pending = [] pending = []
# lazy initialize tree_map # lazy initialize tree_map
tree_map = None tree_map = None
while len(shutdown) != nslave: while len(shutdown) != nslave:
fd, s_addr = self.sock.accept() fd, s_addr = self.sock.accept()
s = SlaveEntry(fd, s_addr) s = SlaveEntry(fd, s_addr)
if s.cmd == 'print': if s.cmd == 'print':
msg = s.sock.recvstr() msg = s.sock.recvstr()
self.handle_print(s, msg) self.handle_print(s, msg)
continue continue
if s.cmd == 'shutdown': if s.cmd == 'shutdown':
assert s.rank >= 0 and s.rank not in shutdown assert s.rank >= 0 and s.rank not in shutdown
assert s.rank not in wait_conn assert s.rank not in wait_conn
@ -280,12 +280,12 @@ class Tracker:
assert s.world_size == -1 or s.world_size == nslave assert s.world_size == -1 or s.world_size == nslave
if s.cmd == 'recover': if s.cmd == 'recover':
assert s.rank >= 0 assert s.rank >= 0
rank = s.decide_rank(job_map) rank = s.decide_rank(job_map)
# batch assignment of ranks # batch assignment of ranks
if rank == -1: if rank == -1:
assert len(todo_nodes) != 0 assert len(todo_nodes) != 0
pending.append(s) pending.append(s)
if len(pending) == len(todo_nodes): if len(pending) == len(todo_nodes):
pending.sort(key = lambda x : x.host) pending.sort(key = lambda x : x.host)
for s in pending: for s in pending: