[coll] Implement shutdown for tracker and comm. (#10208)
- Force shutdown the tracker. - Implement shutdown notice for error handling thread in comm.
This commit is contained in:
@@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() {
|
||||
try {
|
||||
for (auto &all_link : all_links) {
|
||||
if (!all_link.sock.IsClosed()) {
|
||||
all_link.sock.Close();
|
||||
SafeColl(all_link.sock.Close());
|
||||
}
|
||||
}
|
||||
all_links.clear();
|
||||
@@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
tracker.Send(xgboost::StringView{"shutdown"});
|
||||
tracker.Close();
|
||||
SafeColl(tracker.Close());
|
||||
xgboost::system::SocketFinalize();
|
||||
return true;
|
||||
} catch (std::exception const &e) {
|
||||
@@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
|
||||
tracker.Send(xgboost::StringView{"print"});
|
||||
tracker.Send(xgboost::StringView{msg});
|
||||
tracker.Close();
|
||||
SafeColl(tracker.Close());
|
||||
}
|
||||
|
||||
// util to parse data with unit suffix
|
||||
@@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
|
||||
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
|
||||
// create listening socket
|
||||
int port = sock_listen.BindHost();
|
||||
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
||||
sock_listen.Listen();
|
||||
std::int32_t port{0};
|
||||
SafeColl(sock_listen.BindHost(&port));
|
||||
SafeColl(sock_listen.Listen());
|
||||
|
||||
// get number of to connect and number of to accept nodes from tracker
|
||||
int num_conn, num_accept, num_error = 1;
|
||||
do {
|
||||
for (auto & all_link : all_links) {
|
||||
all_link.sock.Close();
|
||||
SafeColl(all_link.sock.Close());
|
||||
}
|
||||
// tracker construct goodset
|
||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||
@@ -352,7 +352,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
LinkRecord r;
|
||||
int hport, hrank;
|
||||
std::string hname;
|
||||
tracker.Recv(&hname);
|
||||
SafeColl(tracker.Recv(&hname));
|
||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
||||
// connect to peer
|
||||
@@ -360,7 +360,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
timeout_sec, &r.sock)
|
||||
.OK()) {
|
||||
num_error += 1;
|
||||
r.sock.Close();
|
||||
SafeColl(r.sock.Close());
|
||||
continue;
|
||||
}
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
@@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
// send back socket listening port to tracker
|
||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
|
||||
// close connection to tracker
|
||||
tracker.Close();
|
||||
SafeColl(tracker.Close());
|
||||
|
||||
// listen to incoming links
|
||||
for (int i = 0; i < num_accept; ++i) {
|
||||
@@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
}
|
||||
if (!match) all_links.emplace_back(std::move(r));
|
||||
}
|
||||
sock_listen.Close();
|
||||
SafeColl(sock_listen.Close());
|
||||
|
||||
this->parent_index = -1;
|
||||
// setup tree links and ring structure
|
||||
@@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
||||
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
||||
|
||||
if (len == 0) {
|
||||
links[parent_index].sock.Close();
|
||||
SafeColl(links[parent_index].sock.Close());
|
||||
return ReportError(&links[parent_index], kRecvZeroLen);
|
||||
}
|
||||
if (len != -1) {
|
||||
|
||||
@@ -270,7 +270,7 @@ class AllreduceBase : public IEngine {
|
||||
ssize_t len = sock.Recv(buffer_head + offset, nmax);
|
||||
// length equals 0, remote disconnected
|
||||
if (len == 0) {
|
||||
sock.Close(); return kRecvZeroLen;
|
||||
SafeColl(sock.Close()); return kRecvZeroLen;
|
||||
}
|
||||
if (len == -1) return Errno2Return();
|
||||
size_read += static_cast<size_t>(len);
|
||||
@@ -289,7 +289,7 @@ class AllreduceBase : public IEngine {
|
||||
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
|
||||
// length equals 0, remote disconnected
|
||||
if (len == 0) {
|
||||
sock.Close(); return kRecvZeroLen;
|
||||
SafeColl(sock.Close()); return kRecvZeroLen;
|
||||
}
|
||||
if (len == -1) return Errno2Return();
|
||||
size_read += static_cast<size_t>(len);
|
||||
|
||||
Reference in New Issue
Block a user