Merge pull request #25 from daiyl0320/master
Add a retry mechanism to ConnectTracker and raise the listen backlog to 128 in rabit_tracker.py
This commit is contained in:
commit
e81a11dd7e
@ -24,6 +24,7 @@ AllreduceBase::AllreduceBase(void) {
|
|||||||
nport_trial = 1000;
|
nport_trial = 1000;
|
||||||
rank = 0;
|
rank = 0;
|
||||||
world_size = -1;
|
world_size = -1;
|
||||||
|
connect_retry = 5;
|
||||||
hadoop_mode = 0;
|
hadoop_mode = 0;
|
||||||
version_number = 0;
|
version_number = 0;
|
||||||
// 32 K items
|
// 32 K items
|
||||||
@ -46,6 +47,7 @@ AllreduceBase::AllreduceBase(void) {
|
|||||||
env_vars.push_back("DMLC_NUM_ATTEMPT");
|
env_vars.push_back("DMLC_NUM_ATTEMPT");
|
||||||
env_vars.push_back("DMLC_TRACKER_URI");
|
env_vars.push_back("DMLC_TRACKER_URI");
|
||||||
env_vars.push_back("DMLC_TRACKER_PORT");
|
env_vars.push_back("DMLC_TRACKER_PORT");
|
||||||
|
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialization function
|
// initialization function
|
||||||
@ -175,6 +177,9 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
if (!strcmp(name, "rabit_reduce_buffer")) {
|
if (!strcmp(name, "rabit_reduce_buffer")) {
|
||||||
reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
|
reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
|
||||||
}
|
}
|
||||||
|
if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
|
||||||
|
connect_retry = atoi(val);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief initialize connection to the tracker
|
* \brief initialize connection to the tracker
|
||||||
@ -185,9 +190,23 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
|||||||
// get information from tracker
|
// get information from tracker
|
||||||
utils::TCPSocket tracker;
|
utils::TCPSocket tracker;
|
||||||
tracker.Create();
|
tracker.Create();
|
||||||
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
|
|
||||||
utils::Socket::Error("Connect");
|
int retry = 0;
|
||||||
}
|
do {
|
||||||
|
fprintf(stderr, "connect to ip: [%s]\n", tracker_uri.c_str());
|
||||||
|
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
|
||||||
|
if (++retry >= connect_retry) {
|
||||||
|
fprintf(stderr, "connect to (failed): [%s]\n", tracker_uri.c_str());
|
||||||
|
utils::Socket::Error("Connect");
|
||||||
|
} else {
|
||||||
|
fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
|
||||||
|
sleep(1);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} while (1);
|
||||||
|
|
||||||
using utils::Assert;
|
using utils::Assert;
|
||||||
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
|
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
|
||||||
"ReConnectLink failure 1");
|
"ReConnectLink failure 1");
|
||||||
|
|||||||
@ -519,6 +519,8 @@ class AllreduceBase : public IEngine {
|
|||||||
int rank;
|
int rank;
|
||||||
// world size
|
// world size
|
||||||
int world_size;
|
int world_size;
|
||||||
|
// connect retry time
|
||||||
|
int connect_retry;
|
||||||
};
|
};
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Tracker script for rabit
|
Tracker script for rabit
|
||||||
Implements the tracker control protocol
|
Implements the tracker control protocol
|
||||||
- start rabit jobs
|
- start rabit jobs
|
||||||
- help nodes to establish links with each other
|
- help nodes to establish links with each other
|
||||||
|
|
||||||
@ -19,13 +19,13 @@ from threading import Thread
|
|||||||
"""
|
"""
|
||||||
Extension of socket to handle recv and send of special data
|
Extension of socket to handle recv and send of special data
|
||||||
"""
|
"""
|
||||||
class ExSocket:
|
class ExSocket:
|
||||||
def __init__(self, sock):
|
def __init__(self, sock):
|
||||||
self.sock = sock
|
self.sock = sock
|
||||||
def recvall(self, nbytes):
|
def recvall(self, nbytes):
|
||||||
res = []
|
res = []
|
||||||
sock = self.sock
|
sock = self.sock
|
||||||
nread = 0
|
nread = 0
|
||||||
while nread < nbytes:
|
while nread < nbytes:
|
||||||
chunk = self.sock.recv(min(nbytes - nread, 1024))
|
chunk = self.sock.recv(min(nbytes - nread, 1024))
|
||||||
nread += len(chunk)
|
nread += len(chunk)
|
||||||
@ -106,7 +106,7 @@ class SlaveEntry:
|
|||||||
for r in conset:
|
for r in conset:
|
||||||
self.sock.sendstr(wait_conn[r].host)
|
self.sock.sendstr(wait_conn[r].host)
|
||||||
self.sock.sendint(wait_conn[r].port)
|
self.sock.sendint(wait_conn[r].port)
|
||||||
self.sock.sendint(r)
|
self.sock.sendint(r)
|
||||||
nerr = self.sock.recvint()
|
nerr = self.sock.recvint()
|
||||||
if nerr != 0:
|
if nerr != 0:
|
||||||
continue
|
continue
|
||||||
@ -121,7 +121,7 @@ class SlaveEntry:
|
|||||||
wait_conn.pop(r, None)
|
wait_conn.pop(r, None)
|
||||||
self.wait_accept = len(badset) - len(conset)
|
self.wait_accept = len(badset) - len(conset)
|
||||||
return rmset
|
return rmset
|
||||||
|
|
||||||
class Tracker:
|
class Tracker:
|
||||||
def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'):
|
def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'):
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
@ -132,7 +132,7 @@ class Tracker:
|
|||||||
break
|
break
|
||||||
except socket.error:
|
except socket.error:
|
||||||
continue
|
continue
|
||||||
sock.listen(16)
|
sock.listen(128)
|
||||||
self.sock = sock
|
self.sock = sock
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
if hostIP == 'auto':
|
if hostIP == 'auto':
|
||||||
@ -145,7 +145,7 @@ class Tracker:
|
|||||||
"""
|
"""
|
||||||
get enviroment variables for slaves
|
get enviroment variables for slaves
|
||||||
can be passed in as args or envs
|
can be passed in as args or envs
|
||||||
"""
|
"""
|
||||||
if self.hostIP == 'dns':
|
if self.hostIP == 'dns':
|
||||||
host = socket.gethostname()
|
host = socket.gethostname()
|
||||||
elif self.hostIP == 'ip':
|
elif self.hostIP == 'ip':
|
||||||
@ -153,14 +153,14 @@ class Tracker:
|
|||||||
else:
|
else:
|
||||||
host = self.hostIP
|
host = self.hostIP
|
||||||
return {'rabit_tracker_uri': host,
|
return {'rabit_tracker_uri': host,
|
||||||
'rabit_tracker_port': self.port}
|
'rabit_tracker_port': self.port}
|
||||||
def get_neighbor(self, rank, nslave):
|
def get_neighbor(self, rank, nslave):
|
||||||
rank = rank + 1
|
rank = rank + 1
|
||||||
ret = []
|
ret = []
|
||||||
if rank > 1:
|
if rank > 1:
|
||||||
ret.append(rank / 2 - 1)
|
ret.append(rank / 2 - 1)
|
||||||
if rank * 2 - 1 < nslave:
|
if rank * 2 - 1 < nslave:
|
||||||
ret.append(rank * 2 - 1)
|
ret.append(rank * 2 - 1)
|
||||||
if rank * 2 < nslave:
|
if rank * 2 < nslave:
|
||||||
ret.append(rank * 2)
|
ret.append(rank * 2)
|
||||||
return ret
|
return ret
|
||||||
@ -198,10 +198,10 @@ class Tracker:
|
|||||||
rlst = self.find_share_ring(tree_map, parent_map, 0)
|
rlst = self.find_share_ring(tree_map, parent_map, 0)
|
||||||
assert len(rlst) == len(tree_map)
|
assert len(rlst) == len(tree_map)
|
||||||
ring_map = {}
|
ring_map = {}
|
||||||
nslave = len(tree_map)
|
nslave = len(tree_map)
|
||||||
for r in range(nslave):
|
for r in range(nslave):
|
||||||
rprev = (r + nslave - 1) % nslave
|
rprev = (r + nslave - 1) % nslave
|
||||||
rnext = (r + 1) % nslave
|
rnext = (r + 1) % nslave
|
||||||
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
||||||
return ring_map
|
return ring_map
|
||||||
|
|
||||||
@ -231,7 +231,7 @@ class Tracker:
|
|||||||
else:
|
else:
|
||||||
parent_map_[rmap[k]] = -1
|
parent_map_[rmap[k]] = -1
|
||||||
return tree_map_, parent_map_, ring_map_
|
return tree_map_, parent_map_, ring_map_
|
||||||
|
|
||||||
def handle_print(self,slave, msg):
|
def handle_print(self,slave, msg):
|
||||||
sys.stdout.write(msg)
|
sys.stdout.write(msg)
|
||||||
|
|
||||||
@ -253,14 +253,14 @@ class Tracker:
|
|||||||
pending = []
|
pending = []
|
||||||
# lazy initialize tree_map
|
# lazy initialize tree_map
|
||||||
tree_map = None
|
tree_map = None
|
||||||
|
|
||||||
while len(shutdown) != nslave:
|
while len(shutdown) != nslave:
|
||||||
fd, s_addr = self.sock.accept()
|
fd, s_addr = self.sock.accept()
|
||||||
s = SlaveEntry(fd, s_addr)
|
s = SlaveEntry(fd, s_addr)
|
||||||
if s.cmd == 'print':
|
if s.cmd == 'print':
|
||||||
msg = s.sock.recvstr()
|
msg = s.sock.recvstr()
|
||||||
self.handle_print(s, msg)
|
self.handle_print(s, msg)
|
||||||
continue
|
continue
|
||||||
if s.cmd == 'shutdown':
|
if s.cmd == 'shutdown':
|
||||||
assert s.rank >= 0 and s.rank not in shutdown
|
assert s.rank >= 0 and s.rank not in shutdown
|
||||||
assert s.rank not in wait_conn
|
assert s.rank not in wait_conn
|
||||||
@ -280,12 +280,12 @@ class Tracker:
|
|||||||
assert s.world_size == -1 or s.world_size == nslave
|
assert s.world_size == -1 or s.world_size == nslave
|
||||||
if s.cmd == 'recover':
|
if s.cmd == 'recover':
|
||||||
assert s.rank >= 0
|
assert s.rank >= 0
|
||||||
|
|
||||||
rank = s.decide_rank(job_map)
|
rank = s.decide_rank(job_map)
|
||||||
# batch assignment of ranks
|
# batch assignment of ranks
|
||||||
if rank == -1:
|
if rank == -1:
|
||||||
assert len(todo_nodes) != 0
|
assert len(todo_nodes) != 0
|
||||||
pending.append(s)
|
pending.append(s)
|
||||||
if len(pending) == len(todo_nodes):
|
if len(pending) == len(todo_nodes):
|
||||||
pending.sort(key = lambda x : x.host)
|
pending.sort(key = lambda x : x.host)
|
||||||
for s in pending:
|
for s in pending:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user