[coll] Implement shutdown for tracker and comm. (#10208)

- Force shutdown the tracker.
- Implement shutdown notice for error handling thread in comm.
This commit is contained in:
Jiaming Yuan
2024-04-20 04:08:17 +08:00
committed by GitHub
parent 8fb05c8c95
commit 3fbb221fec
24 changed files with 553 additions and 199 deletions

View File

@@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() {
try {
for (auto &all_link : all_links) {
if (!all_link.sock.IsClosed()) {
all_link.sock.Close();
SafeColl(all_link.sock.Close());
}
}
all_links.clear();
@@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() {
LOG(FATAL) << rc.Report();
}
tracker.Send(xgboost::StringView{"shutdown"});
tracker.Close();
SafeColl(tracker.Close());
xgboost::system::SocketFinalize();
return true;
} catch (std::exception const &e) {
@@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
tracker.Send(xgboost::StringView{"print"});
tracker.Send(xgboost::StringView{msg});
tracker.Close();
SafeColl(tracker.Close());
}
// util to parse data with unit suffix
@@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
// create listening socket
int port = sock_listen.BindHost();
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
std::int32_t port{0};
SafeColl(sock_listen.BindHost(&port));
SafeColl(sock_listen.Listen());
// get number of to connect and number of to accept nodes from tracker
int num_conn, num_accept, num_error = 1;
do {
for (auto & all_link : all_links) {
all_link.sock.Close();
SafeColl(all_link.sock.Close());
}
// tracker construct goodset
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
@@ -352,7 +352,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.Recv(&hname);
SafeColl(tracker.Recv(&hname));
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
// connect to peer
@@ -360,7 +360,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
timeout_sec, &r.sock)
.OK()) {
num_error += 1;
r.sock.Close();
SafeColl(r.sock.Close());
continue;
}
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
@@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
// send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
SafeColl(tracker.Close());
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
@@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
}
if (!match) all_links.emplace_back(std::move(r));
}
sock_listen.Close();
SafeColl(sock_listen.Close());
this->parent_index = -1;
// setup tree links and ring structure
@@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
if (len == 0) {
links[parent_index].sock.Close();
SafeColl(links[parent_index].sock.Close());
return ReportError(&links[parent_index], kRecvZeroLen);
}
if (len != -1) {

View File

@@ -270,7 +270,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);
@@ -289,7 +289,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);