[coll] Implement shutdown for tracker and comm. (#10208)
- Force shutdown the tracker. - Implement shutdown notice for error handling thread in comm.
This commit is contained in:
@@ -100,6 +100,24 @@ std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E
|
||||
if ((revents & POLLNVAL) != 0) {
|
||||
return xgboost::system::FailWithCode("Invalid polling request.");
|
||||
}
|
||||
if ((revents & POLLHUP) != 0) {
|
||||
// Excerpt from the Linux manual:
|
||||
//
|
||||
// Note that when reading from a channel such as a pipe or a stream socket, this event
|
||||
// merely indicates that the peer closed its end of the channel.Subsequent reads from
|
||||
// the channel will return 0 (end of file) only after all outstanding data in the
|
||||
// channel has been consumed.
|
||||
//
|
||||
// We don't usually have a barrier for exiting workers, it's normal to have one end
|
||||
// exit while the other still reading data.
|
||||
return xgboost::collective::Success();
|
||||
}
|
||||
#if defined(POLLRDHUP)
|
||||
// Linux only flag
|
||||
if ((revents & POLLRDHUP) != 0) {
|
||||
return xgboost::system::FailWithCode("Poll hung up on the other end.");
|
||||
}
|
||||
#endif // defined(POLLRDHUP)
|
||||
return xgboost::collective::Success();
|
||||
}
|
||||
|
||||
@@ -179,9 +197,11 @@ struct PollHelper {
|
||||
}
|
||||
std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
|
||||
if (ret == 0) {
|
||||
return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out));
|
||||
return xgboost::collective::Fail(
|
||||
"Poll timeout:" + std::to_string(timeout.count()) + " seconds.",
|
||||
std::make_error_code(std::errc::timed_out));
|
||||
} else if (ret < 0) {
|
||||
return xgboost::system::FailWithCode("Poll failed.");
|
||||
return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size()));
|
||||
}
|
||||
|
||||
for (auto& pfd : fdset) {
|
||||
|
||||
@@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() {
|
||||
try {
|
||||
for (auto &all_link : all_links) {
|
||||
if (!all_link.sock.IsClosed()) {
|
||||
all_link.sock.Close();
|
||||
SafeColl(all_link.sock.Close());
|
||||
}
|
||||
}
|
||||
all_links.clear();
|
||||
@@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
tracker.Send(xgboost::StringView{"shutdown"});
|
||||
tracker.Close();
|
||||
SafeColl(tracker.Close());
|
||||
xgboost::system::SocketFinalize();
|
||||
return true;
|
||||
} catch (std::exception const &e) {
|
||||
@@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
|
||||
tracker.Send(xgboost::StringView{"print"});
|
||||
tracker.Send(xgboost::StringView{msg});
|
||||
tracker.Close();
|
||||
SafeColl(tracker.Close());
|
||||
}
|
||||
|
||||
// util to parse data with unit suffix
|
||||
@@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
|
||||
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
|
||||
// create listening socket
|
||||
int port = sock_listen.BindHost();
|
||||
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
||||
sock_listen.Listen();
|
||||
std::int32_t port{0};
|
||||
SafeColl(sock_listen.BindHost(&port));
|
||||
SafeColl(sock_listen.Listen());
|
||||
|
||||
// get number of to connect and number of to accept nodes from tracker
|
||||
int num_conn, num_accept, num_error = 1;
|
||||
do {
|
||||
for (auto & all_link : all_links) {
|
||||
all_link.sock.Close();
|
||||
SafeColl(all_link.sock.Close());
|
||||
}
|
||||
// tracker construct goodset
|
||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||
@@ -352,7 +352,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
LinkRecord r;
|
||||
int hport, hrank;
|
||||
std::string hname;
|
||||
tracker.Recv(&hname);
|
||||
SafeColl(tracker.Recv(&hname));
|
||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
||||
// connect to peer
|
||||
@@ -360,7 +360,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
timeout_sec, &r.sock)
|
||||
.OK()) {
|
||||
num_error += 1;
|
||||
r.sock.Close();
|
||||
SafeColl(r.sock.Close());
|
||||
continue;
|
||||
}
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
@@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
// send back socket listening port to tracker
|
||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
|
||||
// close connection to tracker
|
||||
tracker.Close();
|
||||
SafeColl(tracker.Close());
|
||||
|
||||
// listen to incoming links
|
||||
for (int i = 0; i < num_accept; ++i) {
|
||||
@@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
}
|
||||
if (!match) all_links.emplace_back(std::move(r));
|
||||
}
|
||||
sock_listen.Close();
|
||||
SafeColl(sock_listen.Close());
|
||||
|
||||
this->parent_index = -1;
|
||||
// setup tree links and ring structure
|
||||
@@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
||||
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
||||
|
||||
if (len == 0) {
|
||||
links[parent_index].sock.Close();
|
||||
SafeColl(links[parent_index].sock.Close());
|
||||
return ReportError(&links[parent_index], kRecvZeroLen);
|
||||
}
|
||||
if (len != -1) {
|
||||
|
||||
@@ -270,7 +270,7 @@ class AllreduceBase : public IEngine {
|
||||
ssize_t len = sock.Recv(buffer_head + offset, nmax);
|
||||
// length equals 0, remote disconnected
|
||||
if (len == 0) {
|
||||
sock.Close(); return kRecvZeroLen;
|
||||
SafeColl(sock.Close()); return kRecvZeroLen;
|
||||
}
|
||||
if (len == -1) return Errno2Return();
|
||||
size_read += static_cast<size_t>(len);
|
||||
@@ -289,7 +289,7 @@ class AllreduceBase : public IEngine {
|
||||
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
|
||||
// length equals 0, remote disconnected
|
||||
if (len == 0) {
|
||||
sock.Close(); return kRecvZeroLen;
|
||||
SafeColl(sock.Close()); return kRecvZeroLen;
|
||||
}
|
||||
if (len == -1) return Errno2Return();
|
||||
size_read += static_cast<size_t>(len);
|
||||
|
||||
Reference in New Issue
Block a user