Add a retry mechanism to ConnectTracker and change the listen backlog to 128
in rabit_tracker.py
This commit is contained in:
parent
c71ed6fccb
commit
35c3b371ea
@ -24,6 +24,7 @@ AllreduceBase::AllreduceBase(void) {
|
||||
nport_trial = 1000;
|
||||
rank = 0;
|
||||
world_size = -1;
|
||||
connect_retry = 5;
|
||||
hadoop_mode = 0;
|
||||
version_number = 0;
|
||||
// 32 K items
|
||||
@ -46,6 +47,7 @@ AllreduceBase::AllreduceBase(void) {
|
||||
env_vars.push_back("DMLC_NUM_ATTEMPT");
|
||||
env_vars.push_back("DMLC_TRACKER_URI");
|
||||
env_vars.push_back("DMLC_TRACKER_PORT");
|
||||
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
|
||||
}
|
||||
|
||||
// initialization function
|
||||
@ -175,6 +177,9 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
if (!strcmp(name, "rabit_reduce_buffer")) {
|
||||
reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
|
||||
}
|
||||
if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
|
||||
connect_retry = atoi(val);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
@ -185,9 +190,23 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
||||
// get information from tracker
|
||||
utils::TCPSocket tracker;
|
||||
tracker.Create();
|
||||
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
|
||||
utils::Socket::Error("Connect");
|
||||
}
|
||||
|
||||
int retry = 0;
|
||||
do {
|
||||
fprintf(stderr, "connect to ip: [%s]\n", tracker_uri.c_str());
|
||||
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
|
||||
if (++retry >= connect_retry) {
|
||||
fprintf(stderr, "connect to (failed): [%s]\n", tracker_uri.c_str());
|
||||
utils::Socket::Error("Connect");
|
||||
} else {
|
||||
fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
|
||||
sleep(1);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
break;
|
||||
} while (1);
|
||||
|
||||
using utils::Assert;
|
||||
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
|
||||
"ReConnectLink failure 1");
|
||||
|
||||
@ -519,6 +519,8 @@ class AllreduceBase : public IEngine {
|
||||
int rank;
|
||||
// world size
|
||||
int world_size;
|
||||
// maximum number of attempts to connect to the tracker
|
||||
int connect_retry;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""
|
||||
Tracker script for rabit
|
||||
Implements the tracker control protocol
|
||||
Implements the tracker control protocol
|
||||
- start rabit jobs
|
||||
- help nodes to establish links with each other
|
||||
|
||||
@ -19,13 +19,13 @@ from threading import Thread
|
||||
"""
|
||||
Extension of socket to handle recv and send of special data
|
||||
"""
|
||||
class ExSocket:
|
||||
class ExSocket:
|
||||
def __init__(self, sock):
|
||||
self.sock = sock
|
||||
def recvall(self, nbytes):
|
||||
res = []
|
||||
sock = self.sock
|
||||
nread = 0
|
||||
nread = 0
|
||||
while nread < nbytes:
|
||||
chunk = self.sock.recv(min(nbytes - nread, 1024))
|
||||
nread += len(chunk)
|
||||
@ -106,7 +106,7 @@ class SlaveEntry:
|
||||
for r in conset:
|
||||
self.sock.sendstr(wait_conn[r].host)
|
||||
self.sock.sendint(wait_conn[r].port)
|
||||
self.sock.sendint(r)
|
||||
self.sock.sendint(r)
|
||||
nerr = self.sock.recvint()
|
||||
if nerr != 0:
|
||||
continue
|
||||
@ -121,7 +121,7 @@ class SlaveEntry:
|
||||
wait_conn.pop(r, None)
|
||||
self.wait_accept = len(badset) - len(conset)
|
||||
return rmset
|
||||
|
||||
|
||||
class Tracker:
|
||||
def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
@ -132,7 +132,7 @@ class Tracker:
|
||||
break
|
||||
except socket.error:
|
||||
continue
|
||||
sock.listen(16)
|
||||
sock.listen(128)
|
||||
self.sock = sock
|
||||
self.verbose = verbose
|
||||
if hostIP == 'auto':
|
||||
@ -145,7 +145,7 @@ class Tracker:
|
||||
"""
|
||||
get environment variables for slaves
|
||||
can be passed in as args or envs
|
||||
"""
|
||||
"""
|
||||
if self.hostIP == 'dns':
|
||||
host = socket.gethostname()
|
||||
elif self.hostIP == 'ip':
|
||||
@ -153,14 +153,14 @@ class Tracker:
|
||||
else:
|
||||
host = self.hostIP
|
||||
return {'rabit_tracker_uri': host,
|
||||
'rabit_tracker_port': self.port}
|
||||
'rabit_tracker_port': self.port}
|
||||
def get_neighbor(self, rank, nslave):
|
||||
rank = rank + 1
|
||||
ret = []
|
||||
if rank > 1:
|
||||
ret.append(rank / 2 - 1)
|
||||
if rank * 2 - 1 < nslave:
|
||||
ret.append(rank * 2 - 1)
|
||||
ret.append(rank * 2 - 1)
|
||||
if rank * 2 < nslave:
|
||||
ret.append(rank * 2)
|
||||
return ret
|
||||
@ -198,10 +198,10 @@ class Tracker:
|
||||
rlst = self.find_share_ring(tree_map, parent_map, 0)
|
||||
assert len(rlst) == len(tree_map)
|
||||
ring_map = {}
|
||||
nslave = len(tree_map)
|
||||
nslave = len(tree_map)
|
||||
for r in range(nslave):
|
||||
rprev = (r + nslave - 1) % nslave
|
||||
rnext = (r + 1) % nslave
|
||||
rnext = (r + 1) % nslave
|
||||
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
||||
return ring_map
|
||||
|
||||
@ -231,7 +231,7 @@ class Tracker:
|
||||
else:
|
||||
parent_map_[rmap[k]] = -1
|
||||
return tree_map_, parent_map_, ring_map_
|
||||
|
||||
|
||||
def handle_print(self,slave, msg):
|
||||
sys.stdout.write(msg)
|
||||
|
||||
@ -253,14 +253,14 @@ class Tracker:
|
||||
pending = []
|
||||
# lazy initialize tree_map
|
||||
tree_map = None
|
||||
|
||||
|
||||
while len(shutdown) != nslave:
|
||||
fd, s_addr = self.sock.accept()
|
||||
s = SlaveEntry(fd, s_addr)
|
||||
if s.cmd == 'print':
|
||||
msg = s.sock.recvstr()
|
||||
self.handle_print(s, msg)
|
||||
continue
|
||||
continue
|
||||
if s.cmd == 'shutdown':
|
||||
assert s.rank >= 0 and s.rank not in shutdown
|
||||
assert s.rank not in wait_conn
|
||||
@ -280,12 +280,12 @@ class Tracker:
|
||||
assert s.world_size == -1 or s.world_size == nslave
|
||||
if s.cmd == 'recover':
|
||||
assert s.rank >= 0
|
||||
|
||||
|
||||
rank = s.decide_rank(job_map)
|
||||
# batch assignment of ranks
|
||||
if rank == -1:
|
||||
assert len(todo_nodes) != 0
|
||||
pending.append(s)
|
||||
pending.append(s)
|
||||
if len(pending) == len(todo_nodes):
|
||||
pending.sort(key = lambda x : x.host)
|
||||
for s in pending:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user