[coll] Implement shutdown for tracker and comm. (#10208)

- Force shutdown the tracker.
- Implement shutdown notice for error handling thread in comm.
This commit is contained in:
Jiaming Yuan
2024-04-20 04:08:17 +08:00
committed by GitHub
parent 8fb05c8c95
commit 3fbb221fec
24 changed files with 553 additions and 199 deletions

View File

@@ -100,6 +100,24 @@ std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E
if ((revents & POLLNVAL) != 0) {
return xgboost::system::FailWithCode("Invalid polling request.");
}
if ((revents & POLLHUP) != 0) {
// Excerpt from the Linux manual:
//
// Note that when reading from a channel such as a pipe or a stream socket, this event
// merely indicates that the peer closed its end of the channel.Subsequent reads from
// the channel will return 0 (end of file) only after all outstanding data in the
// channel has been consumed.
//
// We don't usually have a barrier for exiting workers, it's normal to have one end
// exit while the other still reading data.
return xgboost::collective::Success();
}
#if defined(POLLRDHUP)
// Linux only flag
if ((revents & POLLRDHUP) != 0) {
return xgboost::system::FailWithCode("Poll hung up on the other end.");
}
#endif // defined(POLLRDHUP)
return xgboost::collective::Success();
}
@@ -179,9 +197,11 @@ struct PollHelper {
}
std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == 0) {
return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out));
return xgboost::collective::Fail(
"Poll timeout:" + std::to_string(timeout.count()) + " seconds.",
std::make_error_code(std::errc::timed_out));
} else if (ret < 0) {
return xgboost::system::FailWithCode("Poll failed.");
return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size()));
}
for (auto& pfd : fdset) {

View File

@@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() {
try {
for (auto &all_link : all_links) {
if (!all_link.sock.IsClosed()) {
all_link.sock.Close();
SafeColl(all_link.sock.Close());
}
}
all_links.clear();
@@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() {
LOG(FATAL) << rc.Report();
}
tracker.Send(xgboost::StringView{"shutdown"});
tracker.Close();
SafeColl(tracker.Close());
xgboost::system::SocketFinalize();
return true;
} catch (std::exception const &e) {
@@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
tracker.Send(xgboost::StringView{"print"});
tracker.Send(xgboost::StringView{msg});
tracker.Close();
SafeColl(tracker.Close());
}
// util to parse data with unit suffix
@@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
// create listening socket
int port = sock_listen.BindHost();
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
std::int32_t port{0};
SafeColl(sock_listen.BindHost(&port));
SafeColl(sock_listen.Listen());
// get number of to connect and number of to accept nodes from tracker
int num_conn, num_accept, num_error = 1;
do {
for (auto & all_link : all_links) {
all_link.sock.Close();
SafeColl(all_link.sock.Close());
}
// tracker construct goodset
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
@@ -352,7 +352,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.Recv(&hname);
SafeColl(tracker.Recv(&hname));
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
// connect to peer
@@ -360,7 +360,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
timeout_sec, &r.sock)
.OK()) {
num_error += 1;
r.sock.Close();
SafeColl(r.sock.Close());
continue;
}
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
@@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
// send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
SafeColl(tracker.Close());
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
@@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
}
if (!match) all_links.emplace_back(std::move(r));
}
sock_listen.Close();
SafeColl(sock_listen.Close());
this->parent_index = -1;
// setup tree links and ring structure
@@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
if (len == 0) {
links[parent_index].sock.Close();
SafeColl(links[parent_index].sock.Close());
return ReportError(&links[parent_index], kRecvZeroLen);
}
if (len != -1) {

View File

@@ -270,7 +270,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);
@@ -289,7 +289,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);