fix hanging trainings (#132)
* fix hanging connections * remove logging
This commit is contained in:
parent
0d6a853212
commit
6e563951af
@ -133,8 +133,6 @@ bool AllreduceBase::Shutdown(void) {
|
|||||||
utils::TCPSocket tracker = this->ConnectTracker();
|
utils::TCPSocket tracker = this->ConnectTracker();
|
||||||
tracker.SendStr(std::string("shutdown"));
|
tracker.SendStr(std::string("shutdown"));
|
||||||
tracker.Close();
|
tracker.Close();
|
||||||
// close listening sockets
|
|
||||||
sock_listen.Close();
|
|
||||||
utils::TCPSocket::Finalize();
|
utils::TCPSocket::Finalize();
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
@ -282,6 +280,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
utils::TCPSocket tracker = this->ConnectTracker();
|
utils::TCPSocket tracker = this->ConnectTracker();
|
||||||
|
fprintf(stdout, "task %s connected to the tracker\n", task_id.c_str());
|
||||||
tracker.SendStr(std::string(cmd));
|
tracker.SendStr(std::string(cmd));
|
||||||
|
|
||||||
// the rank of previous link, next link in ring
|
// the rank of previous link, next link in ring
|
||||||
@ -304,6 +303,8 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
// tracker got overwhelemed and not able to assign correct rank
|
// tracker got overwhelemed and not able to assign correct rank
|
||||||
if (rank == -1) exit(-1);
|
if (rank == -1) exit(-1);
|
||||||
|
|
||||||
|
fprintf(stdout, "task %s got new rank %d\n", task_id.c_str(), rank);
|
||||||
|
|
||||||
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
|
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
|
||||||
sizeof(num_neighbors), "ReConnectLink failure 4");
|
sizeof(num_neighbors), "ReConnectLink failure 4");
|
||||||
for (int i = 0; i < num_neighbors; ++i) {
|
for (int i = 0; i < num_neighbors; ++i) {
|
||||||
@ -317,25 +318,15 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||||
"ReConnectLink failure 4");
|
"ReConnectLink failure 4");
|
||||||
|
|
||||||
if (sock_listen == INVALID_SOCKET || sock_listen.AtMark()) {
|
utils::TCPSocket sock_listen;
|
||||||
if (!sock_listen.IsClosed()) {
|
if (!sock_listen.IsClosed()) {
|
||||||
sock_listen.Close();
|
sock_listen.Close();
|
||||||
}
|
|
||||||
// create listening socket
|
|
||||||
sock_listen.Create();
|
|
||||||
sock_listen.SetKeepAlive(true);
|
|
||||||
// http://deepix.github.io/2016/10/21/tcprst.html
|
|
||||||
sock_listen.SetLinger(0);
|
|
||||||
// [slave_port, slave_port+1 .... slave_port + newrank ...slave_port + nport_trial)
|
|
||||||
// work around processes bind to same port without set reuse option,
|
|
||||||
// start explore from slave_port + newrank towards end
|
|
||||||
port = sock_listen.TryBindHost(slave_port + newrank % nport_trial, slave_port + nport_trial);
|
|
||||||
// if no port bindable, explore first half of range
|
|
||||||
if (port == -1) sock_listen.TryBindHost(slave_port, newrank % nport_trial + slave_port);
|
|
||||||
|
|
||||||
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
|
||||||
sock_listen.Listen();
|
|
||||||
}
|
}
|
||||||
|
// create listening socket
|
||||||
|
sock_listen.Create();
|
||||||
|
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
|
||||||
|
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
||||||
|
sock_listen.Listen();
|
||||||
|
|
||||||
// get number of to connect and number of to accept nodes from tracker
|
// get number of to connect and number of to accept nodes from tracker
|
||||||
int num_conn, num_accept, num_error = 1;
|
int num_conn, num_accept, num_error = 1;
|
||||||
@ -423,7 +414,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
}
|
}
|
||||||
if (!match) all_links.push_back(r);
|
if (!match) all_links.push_back(r);
|
||||||
}
|
}
|
||||||
|
sock_listen.Close();
|
||||||
this->parent_index = -1;
|
this->parent_index = -1;
|
||||||
// setup tree links and ring structure
|
// setup tree links and ring structure
|
||||||
tree_links.plinks.clear();
|
tree_links.plinks.clear();
|
||||||
|
|||||||
@ -571,10 +571,6 @@ class AllreduceBase : public IEngine {
|
|||||||
int world_size;
|
int world_size;
|
||||||
// connect retry time
|
// connect retry time
|
||||||
int connect_retry;
|
int connect_retry;
|
||||||
// backdoor listening peer connection
|
|
||||||
utils::TCPSocket sock_listen;
|
|
||||||
// backdoor port
|
|
||||||
int port = 0;
|
|
||||||
// enable bootstrap cache 0 false 1 true
|
// enable bootstrap cache 0 false 1 true
|
||||||
bool rabit_bootstrap_cache = false;
|
bool rabit_bootstrap_cache = false;
|
||||||
// enable detailed logging
|
// enable detailed logging
|
||||||
|
|||||||
@ -708,9 +708,6 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// print on tracker to help debuging
|
|
||||||
TrackerPrint("[ERROR] rank " + std::to_string(rank) + "@"+
|
|
||||||
host_uri + ":" +std::to_string(port) + " timeout\n");
|
|
||||||
_error("[%d] exit due to time out %d s\n", rank, timeout_sec);
|
_error("[%d] exit due to time out %d s\n", rank, timeout_sec);
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user