diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index e065d8ba1..19b93c644 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1555,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle); /** * @brief Get the arguments needed for running workers. This should be called after - * XGTrackerRun() and XGTrackerWait() + * XGTrackerRun(). * * @param handle The handle to the tracker. * @param args The arguments returned as a JSON document. @@ -1565,16 +1565,19 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle); XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args); /** - * @brief Run the tracker. + * @brief Start the tracker. The tracker runs in the background and this function returns + * once the tracker is started. * * @param handle The handle to the tracker. + * @param config Unused at the moment, preserved for the future. * * @return 0 for success, -1 for failure. */ -XGB_DLL int XGTrackerRun(TrackerHandle handle); +XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config); /** - * @brief Wait for the tracker to finish, should be called after XGTrackerRun(). + * @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This + * function will block until the tracker task is finished or timeout is reached. * * @param handle The handle to the tracker. * @param config JSON encoded configuration. No argument is required yet, preserved for @@ -1582,11 +1585,12 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle); * * @return 0 for success, -1 for failure. */ -XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config); +XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config); /** - * @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker - * cannot close properly, manual interruption is required. + * @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the + * tracker is not properly waited, this function will shutdown all connections with + * the tracker, potentially leading to undefined behavior. * * @param handle The handle to the tracker. * diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index 11520eede..0e098052c 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -124,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) { #endif } +inline std::int32_t ShutdownSocket(SocketT fd) { +#if defined(_WIN32) + auto rc = shutdown(fd, SD_BOTH); + if (rc != 0 && LastError() == WSANOTINITIALISED) { + return 0; + } +#else + auto rc = shutdown(fd, SHUT_RDWR); + if (rc != 0 && LastError() == ENOTCONN) { + return 0; + } +#endif + return rc; +} + inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) { #ifdef _WIN32 return errsv == WSAEWOULDBLOCK; @@ -499,36 +514,49 @@ class TCPSocket { */ [[nodiscard]] HandleT const &Handle() const { return handle_; } /** - * \brief Listen to incoming requests. Should be called after bind. + * @brief Listen to incoming requests. Should be called after bind. */ - void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); } + [[nodiscard]] Result Listen(std::int32_t backlog = 16) { + if (listen(handle_, backlog) != 0) { + return system::FailWithCode("Failed to listen."); + } + return Success(); + } /** - * \brief Bind socket to INADDR_ANY, return the port selected by the OS. + * @brief Bind socket to INADDR_ANY, return the port selected by the OS. */ - [[nodiscard]] in_port_t BindHost() { + [[nodiscard]] Result BindHost(std::int32_t* p_out) { + // Use int32 instead of in_port_t for consistency. We take port as parameter from + // users using other languages, the port is usually stored and passed around as int. if (Domain() == SockDomain::kV6) { auto addr = SockAddrV6::InaddrAny(); auto handle = reinterpret_cast(&addr.Handle()); - xgboost_CHECK_SYS_CALL( - bind(handle_, handle, sizeof(std::remove_reference_t)), 0); + if (bind(handle_, handle, sizeof(std::remove_reference_t)) != 0) { + return system::FailWithCode("bind failed."); + } sockaddr_in6 res_addr; socklen_t addrlen = sizeof(res_addr); - xgboost_CHECK_SYS_CALL( - getsockname(handle_, reinterpret_cast(&res_addr), &addrlen), 0); - return ntohs(res_addr.sin6_port); + if (getsockname(handle_, reinterpret_cast(&res_addr), &addrlen) != 0) { + return system::FailWithCode("getsockname failed."); + } + *p_out = ntohs(res_addr.sin6_port); } else { auto addr = SockAddrV4::InaddrAny(); auto handle = reinterpret_cast(&addr.Handle()); - xgboost_CHECK_SYS_CALL( - bind(handle_, handle, sizeof(std::remove_reference_t)), 0); + if (bind(handle_, handle, sizeof(std::remove_reference_t)) != 0) { + return system::FailWithCode("bind failed."); + } sockaddr_in res_addr; socklen_t addrlen = sizeof(res_addr); - xgboost_CHECK_SYS_CALL( - getsockname(handle_, reinterpret_cast(&res_addr), &addrlen), 0); - return ntohs(res_addr.sin_port); + if (getsockname(handle_, reinterpret_cast(&res_addr), &addrlen) != 0) { + return system::FailWithCode("getsockname failed."); + } + *p_out = ntohs(res_addr.sin_port); } + + return Success(); } [[nodiscard]] auto Port() const { @@ -641,13 +669,13 @@ class TCPSocket { */ std::size_t Send(StringView str); /** - * \brief Receive string, format is matched with the Python socket wrapper in RABIT. + * @brief Receive string, format is matched with the Python socket wrapper in RABIT. */ - std::size_t Recv(std::string *p_str); + [[nodiscard]] Result Recv(std::string *p_str); /** * @brief Close the socket, called automatically in destructor if the socket is not closed. */ - Result Close() { + [[nodiscard]] Result Close() { if (InvalidSocket() != handle_) { auto rc = system::CloseSocket(handle_); #if defined(_WIN32) @@ -664,6 +692,25 @@ class TCPSocket { } return Success(); } + /** + * @brief Call shutdown on the socket. + */ + [[nodiscard]] Result Shutdown() { + if (this->IsClosed()) { + return Success(); + } + auto rc = system::ShutdownSocket(this->Handle()); +#if defined(_WIN32) + // Windows cannot shutdown a socket if it's not connected. + if (rc == -1 && system::LastError() == WSAENOTCONN) { + return Success(); + } +#endif + if (rc != 0) { + return system::FailWithCode("Failed to shutdown socket."); + } + return Success(); + } /** * \brief Create a TCP socket on specified domain. diff --git a/plugin/federated/federated_tracker.cc b/plugin/federated/federated_tracker.cc index 37b6c3639..5051d43cb 100644 --- a/plugin/federated/federated_tracker.cc +++ b/plugin/federated/federated_tracker.cc @@ -125,14 +125,14 @@ Result FederatedTracker::Shutdown() { [[nodiscard]] Json FederatedTracker::WorkerArgs() const { auto rc = this->WaitUntilReady(); - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); std::string host; rc = GetHostAddress(&host); CHECK(rc.OK()); Json args{Object{}}; - args["DMLC_TRACKER_URI"] = String{host}; - args["DMLC_TRACKER_PORT"] = this->Port(); + args["dmlc_tracker_uri"] = String{host}; + args["dmlc_tracker_port"] = this->Port(); return args; } } // namespace xgboost::collective diff --git a/rabit/include/rabit/internal/socket.h b/rabit/include/rabit/internal/socket.h index 89e324482..cec246efd 100644 --- a/rabit/include/rabit/internal/socket.h +++ b/rabit/include/rabit/internal/socket.h @@ -100,6 +100,24 @@ std::enable_if_t, xgboost::collective::Result> PollError(E if ((revents & POLLNVAL) != 0) { return xgboost::system::FailWithCode("Invalid polling request."); } + if ((revents & POLLHUP) != 0) { + // Excerpt from the Linux manual: + // + // Note that when reading from a channel such as a pipe or a stream socket, this event + // merely indicates that the peer closed its end of the channel.Subsequent reads from + // the channel will return 0 (end of file) only after all outstanding data in the + // channel has been consumed. + // + // We don't usually have a barrier for exiting workers, it's normal to have one end + // exit while the other still reading data. + return xgboost::collective::Success(); + } +#if defined(POLLRDHUP) + // Linux only flag + if ((revents & POLLRDHUP) != 0) { + return xgboost::system::FailWithCode("Poll hung up on the other end."); + } +#endif // defined(POLLRDHUP) return xgboost::collective::Success(); } @@ -179,9 +197,11 @@ struct PollHelper { } std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout); if (ret == 0) { - return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out)); + return xgboost::collective::Fail( + "Poll timeout:" + std::to_string(timeout.count()) + " seconds.", + std::make_error_code(std::errc::timed_out)); } else if (ret < 0) { - return xgboost::system::FailWithCode("Poll failed."); + return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size())); } for (auto& pfd : fdset) { diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc index b99eb3763..fcf80b414 100644 --- a/rabit/src/allreduce_base.cc +++ b/rabit/src/allreduce_base.cc @@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() { try { for (auto &all_link : all_links) { if (!all_link.sock.IsClosed()) { - all_link.sock.Close(); + SafeColl(all_link.sock.Close()); } } all_links.clear(); @@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() { LOG(FATAL) << rc.Report(); } tracker.Send(xgboost::StringView{"shutdown"}); - tracker.Close(); + SafeColl(tracker.Close()); xgboost::system::SocketFinalize(); return true; } catch (std::exception const &e) { @@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) { tracker.Send(xgboost::StringView{"print"}); tracker.Send(xgboost::StringView{msg}); - tracker.Close(); + SafeColl(tracker.Close()); } // util to parse data with unit suffix @@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) { auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())}; // create listening socket - int port = sock_listen.BindHost(); - utils::Check(port != -1, "ReConnectLink fail to bind the ports specified"); - sock_listen.Listen(); + std::int32_t port{0}; + SafeColl(sock_listen.BindHost(&port)); + SafeColl(sock_listen.Listen()); // get number of to connect and number of to accept nodes from tracker int num_conn, num_accept, num_error = 1; do { for (auto & all_link : all_links) { - all_link.sock.Close(); + SafeColl(all_link.sock.Close()); } // tracker construct goodset Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), @@ -352,7 +352,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) { LinkRecord r; int hport, hrank; std::string hname; - tracker.Recv(&hname); + SafeColl(tracker.Recv(&hname)); Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9"); Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10"); // connect to peer @@ -360,7 +360,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) { timeout_sec, &r.sock) .OK()) { num_error += 1; - r.sock.Close(); + SafeColl(r.sock.Close()); continue; } Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), @@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) { // send back socket listening port to tracker Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14"); // close connection to tracker - tracker.Close(); + SafeColl(tracker.Close()); // listen to incoming links for (int i = 0; i < num_accept; ++i) { @@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) { } if (!match) all_links.emplace_back(std::move(r)); } - sock_listen.Close(); + SafeColl(sock_listen.Close()); this->parent_index = -1; // setup tree links and ring structure @@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, Recv(sendrecvbuf + size_down_in, total_size - size_down_in); if (len == 0) { - links[parent_index].sock.Close(); + SafeColl(links[parent_index].sock.Close()); return ReportError(&links[parent_index], kRecvZeroLen); } if (len != -1) { diff --git a/rabit/src/allreduce_base.h b/rabit/src/allreduce_base.h index 7724bf3d5..9991c2138 100644 --- a/rabit/src/allreduce_base.h +++ b/rabit/src/allreduce_base.h @@ -270,7 +270,7 @@ class AllreduceBase : public IEngine { ssize_t len = sock.Recv(buffer_head + offset, nmax); // length equals 0, remote disconnected if (len == 0) { - sock.Close(); return kRecvZeroLen; + SafeColl(sock.Close()); return kRecvZeroLen; } if (len == -1) return Errno2Return(); size_read += static_cast(len); @@ -289,7 +289,7 @@ class AllreduceBase : public IEngine { ssize_t len = sock.Recv(p + size_read, max_size - size_read); // length equals 0, remote disconnected if (len == 0) { - sock.Close(); return kRecvZeroLen; + SafeColl(sock.Close()); return kRecvZeroLen; } if (len == -1) return Errno2Return(); size_read += static_cast(len); diff --git a/src/c_api/coll_c_api.cc b/src/c_api/coll_c_api.cc index 24e94f3de..fba2647cc 100644 --- a/src/c_api/coll_c_api.cc +++ b/src/c_api/coll_c_api.cc @@ -5,9 +5,11 @@ #include // for future #include // for unique_ptr #include // for string +#include // for sleep_for #include // for is_same_v, remove_pointer_t #include // for pair +#include "../collective/comm.h" // for DefaultTimeoutSec #include "../collective/tracker.h" // for RabitTracker #include "../common/timer.h" // for Timer #include "c_api_error.h" // for API_BEGIN @@ -26,7 +28,7 @@ using namespace xgboost; // NOLINT namespace { using TrackerHandleT = - std::pair, std::shared_future>; + std::pair, std::shared_future>; TrackerHandleT *GetTrackerHandle(TrackerHandle handle) { xgboost_CHECK_C_ARG_PTR(handle); @@ -41,12 +43,14 @@ struct CollAPIEntry { using CollAPIThreadLocalStore = dmlc::ThreadLocalStore; void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) { - constexpr std::int64_t kDft{60}; + constexpr std::int64_t kDft{collective::DefaultTimeoutSec()}; std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft}; common::Timer timer; timer.Start(); + auto ref = ptr->first; // hold a reference to that free don't delete it while waiting. + auto fut = ptr->second; while (fut.valid()) { auto res = fut.wait_for(wait_for); @@ -72,15 +76,15 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) { Json jconfig = Json::Load(config); auto type = RequiredArg(jconfig, "dmlc_communicator", __func__); - std::unique_ptr tptr; + std::shared_ptr tptr; if (type == "federated") { #if defined(XGBOOST_USE_FEDERATED) - tptr = std::make_unique(jconfig); + tptr = std::make_shared(jconfig); #else LOG(FATAL) << error::NoFederated(); #endif // defined(XGBOOST_USE_FEDERATED) } else if (type == "rabit") { - tptr = std::make_unique(jconfig); + tptr = std::make_shared(jconfig); } else { LOG(FATAL) << "Unknown communicator:" << type; } @@ -103,7 +107,7 @@ XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) { API_END(); } -XGB_DLL int XGTrackerRun(TrackerHandle handle) { +XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *) { API_BEGIN(); auto *ptr = GetTrackerHandle(handle); CHECK(!ptr->second.valid()) << "Tracker is already running."; @@ -111,13 +115,14 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle) { API_END(); } -XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) { +XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config) { API_BEGIN(); auto *ptr = GetTrackerHandle(handle); xgboost_CHECK_C_ARG_PTR(config); auto jconfig = Json::Load(StringView{config}); // Internally, 0 indicates no timeout, which is the default since we don't want to // interrupt the model training. + xgboost_CHECK_C_ARG_PTR(config); auto timeout = OptionalArg(jconfig, "timeout", std::int64_t{0}); WaitImpl(ptr, std::chrono::seconds{timeout}); API_END(); @@ -125,8 +130,24 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) { XGB_DLL int XGTrackerFree(TrackerHandle handle) { API_BEGIN(); + using namespace std::chrono_literals; // NOLINT auto *ptr = GetTrackerHandle(handle); + ptr->first->Stop(); + // The wait is not necessary since we just called stop, just reusing the function to do + // any potential cleanups. WaitImpl(ptr, ptr->first->Timeout()); + common::Timer timer; + timer.Start(); + // Make sure no one else is waiting on the tracker. + while (!ptr->first.unique()) { + auto ela = timer.Duration().count(); + if (ela > ptr->first->Timeout().count()) { + LOG(WARNING) << "Time out " << ptr->first->Timeout().count() + << " seconds reached for TrackerFree, killing the tracker."; + break; + } + std::this_thread::sleep_for(64ms); + } delete ptr; API_END(); } diff --git a/src/collective/coll.cc b/src/collective/coll.cc index c6d03c6df..b720d09b7 100644 --- a/src/collective/coll.cc +++ b/src/collective/coll.cc @@ -38,6 +38,10 @@ bool constexpr IsFloatingPointV() { auto redop_fn = [](auto lhs, auto out, auto elem_op) { auto p_lhs = lhs.data(); auto p_out = out.data(); +#if defined(__GNUC__) || defined(__clang__) + // For the sum op, one can verify the simd by: addps %xmm15, %xmm14 +#pragma omp simd +#endif for (std::size_t i = 0; i < lhs.size(); ++i) { p_out[i] = elem_op(p_lhs[i], p_out[i]); } diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 23a8e89ed..50a14aaaf 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -5,9 +5,11 @@ #include // for copy #include // for seconds +#include // for int32_t #include // for exit #include // for shared_ptr #include // for string +#include // for thread #include // for move, forward #if !defined(XGBOOST_USE_NCCL) #include "../common/common.h" // for AssertNCCLSupport @@ -184,13 +186,30 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st return Success(); } -RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, - std::int32_t retry, std::string task_id, StringView nccl_path) - : HostComm{std::move(host), port, timeout, retry, std::move(task_id)}, +namespace { +std::string InitLog(std::string task_id, std::int32_t rank) { + if (task_id.empty()) { + return "Rank " + std::to_string(rank); + } + return "Task " + task_id + " got rank " + std::to_string(rank); +} +} // namespace + +RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port, + std::chrono::seconds timeout, std::int32_t retry, std::string task_id, + StringView nccl_path) + : HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)}, nccl_path_{std::move(nccl_path)} { + if (this->TrackerInfo().host.empty()) { + // Not in a distributed environment. + LOG(CONSOLE) << InitLog(task_id_, rank_); + return; + } + loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT auto rc = this->Bootstrap(timeout_, retry_, task_id_); if (!rc.OK()) { + this->ResetState(); SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc))); } } @@ -217,20 +236,54 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { // Start command TCPSocket listener = TCPSocket::Create(tracker.Domain()); - std::int32_t lport = listener.BindHost(); - listener.Listen(); + std::int32_t lport{0}; + rc = std::move(rc) << [&] { + return listener.BindHost(&lport); + } << [&] { + return listener.Listen(); + }; + if (!rc.OK()) { + return rc; + } // create worker for listening to error notice. auto domain = tracker.Domain(); std::shared_ptr error_sock{TCPSocket::CreatePtr(domain)}; - auto eport = error_sock->BindHost(); - error_sock->Listen(); + std::int32_t eport{0}; + rc = std::move(rc) << [&] { + return error_sock->BindHost(&eport); + } << [&] { + return error_sock->Listen(); + }; + if (!rc.OK()) { + return rc; + } + error_port_ = eport; + error_worker_ = std::thread{[error_sock = std::move(error_sock)] { - auto conn = error_sock->Accept(); + TCPSocket conn; + SockAddress addr; + auto rc = error_sock->Accept(&conn, &addr); + // On Linux, a shutdown causes an invalid argument error; + if (rc.Code() == std::errc::invalid_argument) { + return; + } // On Windows, accept returns a closed socket after finalize. if (conn.IsClosed()) { return; } + // The error signal is from the tracker, while shutdown signal is from the shutdown method + // of the RabitComm class (this). + bool is_error{false}; + rc = proto::Error{}.RecvSignal(&conn, &is_error); + if (!rc.OK()) { + LOG(WARNING) << rc.Report(); + return; + } + if (!is_error) { + return; // shutdown + } + LOG(WARNING) << "Another worker is running into error."; #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0 // exit is nicer than abort as the former performs cleanups. @@ -239,6 +292,9 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { LOG(FATAL) << "abort"; #endif }}; + // The worker thread is detached here to avoid the need to handle it later during + // destruction. For C++, if a thread is not joined or detached, it will segfault during + // destruction. error_worker_.detach(); proto::Start start; @@ -251,7 +307,10 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { // get ring neighbors std::string snext; - tracker.Recv(&snext); + rc = tracker.Recv(&snext); + if (!rc.OK()) { + return Fail("Failed to receive the rank for the next worker.", std::move(rc)); + } auto jnext = Json::Load(StringView{snext}); proto::PeerInfo ninfo{jnext}; @@ -268,14 +327,21 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { CHECK(this->channels_.empty()); for (auto& w : workers) { if (w) { - rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); } - << [&] { return w->SetKeepAlive(); }; + rc = std::move(rc) << [&] { + return w->SetNoDelay(); + } << [&] { + return w->NonBlocking(true); + } << [&] { + return w->SetKeepAlive(); + }; } if (!rc.OK()) { return rc; } this->channels_.emplace_back(std::make_shared(*this, w)); } + + LOG(CONSOLE) << InitLog(task_id_, rank_); return rc; } @@ -283,6 +349,8 @@ RabitComm::~RabitComm() noexcept(false) { if (!this->IsDistributed()) { return; } + LOG(WARNING) << "The communicator is being destroyed without a call to shutdown first. This can " + "lead to undefined behaviour."; auto rc = this->Shutdown(); if (!rc.OK()) { LOG(WARNING) << rc.Report(); @@ -293,30 +361,49 @@ RabitComm::~RabitComm() noexcept(false) { if (!this->IsDistributed()) { return Success(); } - + // Tell the tracker that this worker is shutting down. TCPSocket tracker; + // Tell the error hanlding thread that we are shutting down. + TCPSocket err_client; + return Success() << [&] { return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World()); } << [&] { return this->Block(); } << [&] { - Json jcmd{Object{}}; - jcmd["cmd"] = Integer{static_cast(proto::CMD::kShutdown)}; - auto scmd = Json::Dump(jcmd); - auto n_bytes = tracker.Send(scmd); - if (n_bytes != scmd.size()) { - return Fail("Faled to send cmd."); - } - - this->ResetState(); - return Success(); + return proto::ShutdownCMD{}.Send(&tracker); } << [&] { this->channels_.clear(); return Success(); + } << [&] { + // Use tracker address to determine whether we want to use IPv6. + auto taddr = MakeSockAddress(xgboost::StringView{this->tracker_.host}, this->tracker_.port); + // Shutdown the error handling thread. We signal the thread through socket, + // alternatively, we can get the native handle and use pthread_cancel. But using a + // socket seems to be clearer as we know what's happening. + auto const& addr = taddr.IsV4() ? SockAddrV4::Loopback().Addr() : SockAddrV6::Loopback().Addr(); + // We use hardcoded 10 seconds and 1 retry here since we are just connecting to a + // local socket. For a normal OS, this should be enough time to schedule the + // connection. + auto rc = Connect(StringView{addr}, this->error_port_, 1, + std::min(std::chrono::seconds{10}, timeout_), &err_client); + this->ResetState(); + if (!rc.OK()) { + return Fail("Failed to connect to the error socket.", std::move(rc)); + } + return rc; + } << [&] { + // We put error thread shutdown at the end so that we have a better chance to finish + // the previous more important steps. + return proto::Error{}.SignalShutdown(&err_client); }; } [[nodiscard]] Result RabitComm::LogTracker(std::string msg) const { + if (!this->IsDistributed()) { + LOG(CONSOLE) << msg; + return Success(); + } TCPSocket out; proto::Print print; return Success() << [&] { return this->ConnectTracker(&out); } @@ -324,8 +411,11 @@ RabitComm::~RabitComm() noexcept(false) { } [[nodiscard]] Result RabitComm::SignalError(Result const& res) { - TCPSocket out; - return Success() << [&] { return this->ConnectTracker(&out); } - << [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); }; + TCPSocket tracker; + return Success() << [&] { + return this->ConnectTracker(&tracker); + } << [&] { + return proto::ErrorCMD{}.WorkerSend(&tracker, res); + }; } } // namespace xgboost::collective diff --git a/src/collective/comm.h b/src/collective/comm.h index 6ad5bc5c1..a41f47be9 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -1,10 +1,10 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for seconds #include // for size_t -#include // for int32_t +#include // for int32_t, int64_t #include // for shared_ptr #include // for string #include // for thread @@ -20,7 +20,7 @@ namespace xgboost::collective { -inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min +inline constexpr std::int64_t DefaultTimeoutSec() { return 300; } // 5min inline constexpr std::int32_t DefaultRetry() { return 3; } // indexing into the ring @@ -51,7 +51,10 @@ class Comm : public std::enable_shared_from_this { proto::PeerInfo tracker_; SockDomain domain_{SockDomain::kV4}; + std::thread error_worker_; + std::int32_t error_port_; + std::string task_id_; std::vector> channels_; std::shared_ptr loop_{nullptr}; // fixme: require federated comm to have a timeout @@ -59,6 +62,13 @@ class Comm : public std::enable_shared_from_this { void ResetState() { this->world_ = -1; this->rank_ = 0; + this->timeout_ = std::chrono::seconds{DefaultTimeoutSec()}; + + tracker_ = proto::PeerInfo{}; + this->task_id_.clear(); + channels_.clear(); + + loop_.reset(); } public: @@ -79,9 +89,9 @@ class Comm : public std::enable_shared_from_this { [[nodiscard]] auto Retry() const { return retry_; } [[nodiscard]] auto TaskID() const { return task_id_; } - [[nodiscard]] auto Rank() const { return rank_; } - [[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; } - [[nodiscard]] bool IsDistributed() const { return world_ != -1; } + [[nodiscard]] auto Rank() const noexcept { return rank_; } + [[nodiscard]] auto World() const noexcept { return IsDistributed() ? world_ : 1; } + [[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; } void Submit(Loop::Op op) const { CHECK(loop_); loop_->Submit(op); @@ -120,20 +130,20 @@ class RabitComm : public HostComm { [[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry, std::string task_id); - [[nodiscard]] Result Shutdown() final; public: // bootstrapping construction. RabitComm() = default; - // ctor for testing where environment is known. - RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, - std::int32_t retry, std::string task_id, StringView nccl_path); + RabitComm(std::string const& tracker_host, std::int32_t tracker_port, + std::chrono::seconds timeout, std::int32_t retry, std::string task_id, + StringView nccl_path); ~RabitComm() noexcept(false) override; [[nodiscard]] bool IsFederated() const override { return false; } [[nodiscard]] Result LogTracker(std::string msg) const override; [[nodiscard]] Result SignalError(Result const&) override; + [[nodiscard]] Result Shutdown() final; [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl) const override; }; diff --git a/src/collective/comm_group.cc b/src/collective/comm_group.cc index 7408882f6..18a5ba8a7 100644 --- a/src/collective/comm_group.cc +++ b/src/collective/comm_group.cc @@ -64,6 +64,9 @@ CommGroup::CommGroup() auto const& obj = get(config); auto it = obj.find(upper); + if (it != obj.cend() && obj.find(name) != obj.cend()) { + LOG(FATAL) << "Duplicated parameter:" << name; + } if (it != obj.cend()) { return OptionalArg(config, upper, dft); } else { @@ -77,14 +80,14 @@ CommGroup::CommGroup() auto task_id = get_param("dmlc_task_id", std::string{}, String{}); if (type == "rabit") { - auto host = get_param("dmlc_tracker_uri", std::string{}, String{}); - auto port = get_param("dmlc_tracker_port", static_cast(0), Integer{}); + auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{}); + auto tracker_port = get_param("dmlc_tracker_port", static_cast(0), Integer{}); auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{}); - auto ptr = - new CommGroup{std::shared_ptr{new RabitComm{ // NOLINT - host, static_cast(port), std::chrono::seconds{timeout}, - static_cast(retry), task_id, nccl}}, - std::shared_ptr(new Coll{})}; // NOLINT + auto ptr = new CommGroup{ + std::shared_ptr{new RabitComm{ // NOLINT + tracker_host, static_cast(tracker_port), std::chrono::seconds{timeout}, + static_cast(retry), task_id, nccl}}, + std::shared_ptr(new Coll{})}; // NOLINT return ptr; } else if (type == "federated") { #if defined(XGBOOST_USE_FEDERATED) diff --git a/src/collective/comm_group.h b/src/collective/comm_group.h index 61a58ba56..a98de0c16 100644 --- a/src/collective/comm_group.h +++ b/src/collective/comm_group.h @@ -30,9 +30,9 @@ class CommGroup { public: CommGroup(); - [[nodiscard]] auto World() const { return comm_->World(); } - [[nodiscard]] auto Rank() const { return comm_->Rank(); } - [[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); } + [[nodiscard]] auto World() const noexcept { return comm_->World(); } + [[nodiscard]] auto Rank() const noexcept { return comm_->Rank(); } + [[nodiscard]] bool IsDistributed() const noexcept { return comm_->IsDistributed(); } [[nodiscard]] Result Finalize() const { return Success() << [this] { diff --git a/src/collective/protocol.h b/src/collective/protocol.h index 96edf4e29..29e6c9619 100644 --- a/src/collective/protocol.h +++ b/src/collective/protocol.h @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for int32_t @@ -58,6 +58,7 @@ struct Magic { } }; +// Basic commands for communication between workers and the tracker. enum class CMD : std::int32_t { kInvalid = 0, kStart = 1, @@ -84,7 +85,10 @@ struct Connect { [[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank, std::string* task_id) const { std::string init; - sock->Recv(&init); + auto rc = sock->Recv(&init); + if (!rc.OK()) { + return Fail("Connect protocol failed.", std::move(rc)); + } auto jinit = Json::Load(StringView{init}); *world = get(jinit["world_size"]); *rank = get(jinit["rank"]); @@ -122,9 +126,9 @@ class Start { } [[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const { std::string scmd; - auto n_bytes = tracker->Recv(&scmd); - if (n_bytes <= 0) { - return Fail("Failed to recv init command from tracker."); + auto rc = tracker->Recv(&scmd); + if (!rc.OK()) { + return Fail("Failed to recv init command from tracker.", std::move(rc)); } auto jcmd = Json::Load(scmd); auto world = get(jcmd["world_size"]); @@ -132,7 +136,7 @@ class Start { return Fail("Invalid world size."); } *p_world = world; - return Success(); + return rc; } [[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world, std::int32_t* p_port, TCPSocket* p_sock, @@ -150,6 +154,7 @@ class Start { } }; +// Protocol for communicating with the tracker for printing message. struct Print { [[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const { Json jcmd{Object{}}; @@ -172,6 +177,7 @@ struct Print { } }; +// Protocol for communicating with the tracker during error. struct ErrorCMD { [[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const { auto msg = res.Report(); @@ -199,6 +205,7 @@ struct ErrorCMD { } }; +// Protocol for communicating with the tracker during shutdown. struct ShutdownCMD { [[nodiscard]] Result Send(TCPSocket* peer) const { Json jcmd{Object{}}; @@ -211,4 +218,40 @@ struct ShutdownCMD { return Success(); } }; + +// Protocol for communicating with the local error handler during error or shutdown. Only +// one protocol that doesn't have the tracker involved. +struct Error { + constexpr static std::int32_t ShutdownSignal() { return 0; } + constexpr static std::int32_t ErrorSignal() { return -1; } + + [[nodiscard]] Result SignalError(TCPSocket* worker) const { + std::int32_t err{ErrorSignal()}; + auto n_sent = worker->SendAll(&err, sizeof(err)); + if (n_sent == sizeof(err)) { + return Success(); + } + return Fail("Failed to send error signal"); + } + // self is localhost, we are sending the signal to the error handling thread for it to + // close. + [[nodiscard]] Result SignalShutdown(TCPSocket* self) const { + std::int32_t err{ShutdownSignal()}; + auto n_sent = self->SendAll(&err, sizeof(err)); + if (n_sent == sizeof(err)) { + return Success(); + } + return Fail("Failed to send shutdown signal"); + } + // get signal, either for error or for shutdown. + [[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const { + std::int32_t err{ShutdownSignal()}; + auto n_recv = peer->RecvAll(&err, sizeof(err)); + if (n_recv == sizeof(err)) { + *p_is_error = err == 1; + return Success(); + } + return Fail("Failed to receive error signal."); + } +}; } // namespace xgboost::collective::proto diff --git a/src/collective/socket.cc b/src/collective/socket.cc index 43da366bd..737ce584e 100644 --- a/src/collective/socket.cc +++ b/src/collective/socket.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023 by XGBoost Contributors + * Copyright 2022-2024, XGBoost Contributors */ #include "xgboost/collective/socket.h" @@ -8,7 +8,8 @@ #include // std::int32_t #include // std::memcpy, std::memset #include // for path -#include // std::error_code, std::system_category +#include // for error_code, system_category +#include // for sleep_for #include "rabit/internal/socket.h" // for PollHelper #include "xgboost/collective/result.h" // for Result @@ -65,14 +66,18 @@ std::size_t TCPSocket::Send(StringView str) { return bytes; } -std::size_t TCPSocket::Recv(std::string *p_str) { +[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) { CHECK(!this->IsClosed()); std::int32_t len; - CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length."; + if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) { + return Fail("Failed to recv string length."); + } p_str->resize(len); auto bytes = this->RecvAll(&(*p_str)[0], len); - CHECK_EQ(bytes, len) << "Failed to recv string."; - return bytes; + if (static_cast(bytes) != len) { + return Fail("Failed to recv string."); + } + return Success(); } [[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, @@ -110,11 +115,7 @@ std::size_t TCPSocket::Recv(std::string *p_str) { for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) { if (attempt > 0) { LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time."; -#if defined(_MSC_VER) || defined(__MINGW32__) - Sleep(attempt << 1); -#else - sleep(attempt << 1); -#endif + std::this_thread::sleep_for(std::chrono::seconds{attempt << 1}); } auto rc = connect(conn.Handle(), addr_handle, addr_len); @@ -158,8 +159,8 @@ std::size_t TCPSocket::Recv(std::string *p_str) { std::stringstream ss; ss << "Failed to connect to " << host << ":" << port; - conn.Close(); - return Fail(ss.str(), std::move(last_error)); + auto close_rc = conn.Close(); + return Fail(ss.str(), std::move(close_rc) + std::move(last_error)); } [[nodiscard]] Result GetHostName(std::string *p_out) { diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc index 3fdf75ead..142483ccf 100644 --- a/src/collective/tracker.cc +++ b/src/collective/tracker.cc @@ -1,6 +1,7 @@ /** * Copyright 2023-2024, XGBoost Contributors */ +#include "rabit/internal/socket.h" #if defined(__unix__) || defined(__APPLE__) #include // gethostbyname #include // socket, AF_INET6, AF_INET, connect, getsockname @@ -70,10 +71,13 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); } << [&] { std::string cmd; - sock_.Recv(&cmd); + auto rc = sock_.Recv(&cmd); + if (!rc.OK()) { + return rc; + } jcmd = Json::Load(StringView{cmd}); cmd_ = static_cast(get(jcmd["cmd"])); - return Success(); + return rc; } << [&] { if (cmd_ == proto::CMD::kStart) { proto::Start start; @@ -100,14 +104,18 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA RabitTracker::RabitTracker(Json const& config) : Tracker{config} { std::string self; - auto rc = collective::GetHostAddress(&self); - host_ = OptionalArg(config, "host", self); + auto rc = Success() << [&] { + return collective::GetHostAddress(&self); + } << [&] { + host_ = OptionalArg(config, "host", self); - auto addr = MakeSockAddress(xgboost::StringView{host_}, 0); - listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6); - rc = listener_.Bind(host_, &this->port_); + auto addr = MakeSockAddress(xgboost::StringView{host_}, 0); + listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6); + return listener_.Bind(host_, &this->port_); + } << [&] { + return listener_.Listen(); + }; SafeColl(rc); - listener_.Listen(); } Result RabitTracker::Bootstrap(std::vector* p_workers) { @@ -220,9 +228,13 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { // // retry is set to 1, just let the worker timeout or error. Otherwise the // tracker and the worker might be waiting for each other. - auto rc = Connect(w.first, w.second, 1, timeout_, &out); + auto rc = Success() << [&] { + return Connect(w.first, w.second, 1, timeout_, &out); + } << [&] { + return proto::Error{}.SignalError(&out); + }; if (!rc.OK()) { - return Fail("Failed to inform workers to stop."); + return Fail("Failed to inform worker:" + w.first + " for error.", std::move(rc)); } } return Success(); @@ -231,13 +243,37 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { return std::async(std::launch::async, [this, handle_error] { State state{this->n_workers_}; + auto select_accept = [&](TCPSocket* sock, auto* addr) { + // accept with poll so that we can enable timeout and interruption. + rabit::utils::PollHelper poll; + auto rc = Success() << [&] { + std::lock_guard lock{listener_mu_}; + return listener_.NonBlocking(true); + } << [&] { + std::lock_guard lock{listener_mu_}; + poll.WatchRead(listener_); + if (state.running) { + // Don't timeout if the communicator group is up and running. + return poll.Poll(std::chrono::seconds{-1}); + } else { + // Have timeout for workers to bootstrap. + return poll.Poll(timeout_); + } + } << [&] { + // this->Stop() closes the socket with a lock. Therefore, when the accept returns + // due to shutdown, the state is still valid (closed). + return listener_.Accept(sock, addr); + }; + return rc; + }; + while (state.ShouldContinue()) { TCPSocket sock; SockAddress addr; this->ready_ = true; - auto rc = listener_.Accept(&sock, &addr); + auto rc = select_accept(&sock, &addr); if (!rc.OK()) { - return Fail("Failed to accept connection.", std::move(rc)); + return Fail("Failed to accept connection.", this->Stop() + std::move(rc)); } auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)}; @@ -252,7 +288,7 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { state.Error(); rc = handle_error(worker); if (!rc.OK()) { - return Fail("Failed to handle abort.", std::move(rc)); + return Fail("Failed to handle abort.", this->Stop() + std::move(rc)); } } @@ -262,7 +298,7 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { state.Bootstrap(); } if (!rc.OK()) { - return rc; + return this->Stop() + std::move(rc); } continue; } @@ -289,12 +325,11 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { } case proto::CMD::kInvalid: default: { - return Fail("Invalid command received."); + return Fail("Invalid command received.", this->Stop()); } } } - ready_ = false; - return Success(); + return this->Stop(); }); } @@ -303,11 +338,30 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { SafeColl(rc); Json args{Object{}}; - args["DMLC_TRACKER_URI"] = String{host_}; - args["DMLC_TRACKER_PORT"] = this->Port(); + args["dmlc_tracker_uri"] = String{host_}; + args["dmlc_tracker_port"] = this->Port(); return args; } +[[nodiscard]] Result RabitTracker::Stop() { + if (!this->Ready()) { + return Success(); + } + + ready_ = false; + std::lock_guard lock{listener_mu_}; + if (this->listener_.IsClosed()) { + return Success(); + } + + return Success() << [&] { + // This should have the effect of stopping the `accept` call. + return this->listener_.Shutdown(); + } << [&] { + return listener_.Close(); + }; +} + [[nodiscard]] Result GetHostAddress(std::string* out) { auto rc = GetHostName(out); if (!rc.OK()) { diff --git a/src/collective/tracker.h b/src/collective/tracker.h index e15aaee59..af30e0be7 100644 --- a/src/collective/tracker.h +++ b/src/collective/tracker.h @@ -36,15 +36,18 @@ namespace xgboost::collective { * signal an error to the tracker and the tracker will notify other workers. */ class Tracker { + public: + enum class SortBy : std::int8_t { + kHost = 0, + kTask = 1, + }; + protected: // How to sort the workers, either by host name or by task ID. When using a multi-GPU // setting, multiple workers can occupy the same host, in which case one should sort // workers by task. Due to compatibility reason, the task ID is not always available, so // we use host as the default. - enum class SortBy : std::int8_t { - kHost = 0, - kTask = 1, - } sortby_; + SortBy sortby_; protected: std::int32_t n_workers_{0}; @@ -54,10 +57,7 @@ class Tracker { public: explicit Tracker(Json const& config); - Tracker(std::int32_t n_worders, std::int32_t port, std::chrono::seconds timeout) - : n_workers_{n_worders}, port_{port}, timeout_{timeout} {} - - virtual ~Tracker() noexcept(false){}; // NOLINT + virtual ~Tracker() = default; [[nodiscard]] Result WaitUntilReady() const; @@ -69,6 +69,11 @@ class Tracker { * @brief Flag to indicate whether the server is running. */ [[nodiscard]] bool Ready() const { return ready_; } + /** + * @brief Shutdown the tracker, cannot be restarted again. Useful when the tracker hangs while + * calling accept. + */ + virtual Result Stop() { return Success(); } }; class RabitTracker : public Tracker { @@ -127,28 +132,22 @@ class RabitTracker : public Tracker { // record for how to reach out to workers if error happens. std::vector> worker_error_handles_; // listening socket for incoming workers. - // - // At the moment, the listener calls accept without first polling. We can add an - // additional unix domain socket to allow cancelling the accept. TCPSocket listener_; + // mutex for protecting the listener, used to prevent race when it's listening while + // another thread tries to shut it down. + std::mutex listener_mu_; Result Bootstrap(std::vector* p_workers); public: - explicit RabitTracker(StringView host, std::int32_t n_worders, std::int32_t port, - std::chrono::seconds timeout) - : Tracker{n_worders, port, timeout}, host_{host.c_str(), host.size()} { - listener_ = TCPSocket::Create(SockDomain::kV4); - auto rc = listener_.Bind(host, &this->port_); - CHECK(rc.OK()) << rc.Report(); - listener_.Listen(); - } - explicit RabitTracker(Json const& config); - ~RabitTracker() noexcept(false) override = default; + ~RabitTracker() override = default; std::future Run() override; [[nodiscard]] Json WorkerArgs() const override; + // Stop the tracker without waiting. This is to prevent the tracker from hanging when + // one of the workers failes to start. + [[nodiscard]] Result Stop() override; }; // Prob the public IP address of the host, need a better method. diff --git a/tests/cpp/collective/test_coll_c_api.cc b/tests/cpp/collective/test_coll_c_api.cc index d80fbc140..c7229ff77 100644 --- a/tests/cpp/collective/test_coll_c_api.cc +++ b/tests/cpp/collective/test_coll_c_api.cc @@ -25,13 +25,13 @@ TEST_F(TrackerAPITest, CAPI) { auto config_str = Json::Dump(config); auto rc = XGTrackerCreate(config_str.c_str(), &handle); ASSERT_EQ(rc, 0); - rc = XGTrackerRun(handle); + rc = XGTrackerRun(handle, nullptr); ASSERT_EQ(rc, 0); std::thread bg_wait{[&] { Json config{Object{}}; auto config_str = Json::Dump(config); - auto rc = XGTrackerWait(handle, config_str.c_str()); + auto rc = XGTrackerWaitFor(handle, config_str.c_str()); ASSERT_EQ(rc, 0); }}; @@ -42,8 +42,8 @@ TEST_F(TrackerAPITest, CAPI) { std::string host; ASSERT_TRUE(GetHostAddress(&host).OK()); - ASSERT_EQ(host, get(args["DMLC_TRACKER_URI"])); - auto port = get(args["DMLC_TRACKER_PORT"]); + ASSERT_EQ(host, get(args["dmlc_tracker_uri"])); + auto port = get(args["dmlc_tracker_port"]); ASSERT_NE(port, 0); std::vector workers; diff --git a/tests/cpp/collective/test_comm.cc b/tests/cpp/collective/test_comm.cc index 8e69b2f8e..c1eb06465 100644 --- a/tests/cpp/collective/test_comm.cc +++ b/tests/cpp/collective/test_comm.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include @@ -14,7 +14,7 @@ class CommTest : public TrackerTest {}; TEST_F(CommTest, Channel) { auto n_workers = 4; - RabitTracker tracker{host, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; auto fut = tracker.Run(); std::vector workers; @@ -29,7 +29,7 @@ TEST_F(CommTest, Channel) { return p_chan->SendAll( EraseType(common::Span{&i, static_cast(1)})); } << [&] { return p_chan->Block(); }; - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } else { auto p_chan = worker.Comm().Chan(i - 1); std::int32_t r{-1}; @@ -37,7 +37,7 @@ TEST_F(CommTest, Channel) { return p_chan->RecvAll( EraseType(common::Span{&r, static_cast(1)})); } << [&] { return p_chan->Block(); }; - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(r, i - 1); } }); diff --git a/tests/cpp/collective/test_comm_group.cc b/tests/cpp/collective/test_comm_group.cc index 0f6bc23a2..3b1b5c5df 100644 --- a/tests/cpp/collective/test_comm_group.cc +++ b/tests/cpp/collective/test_comm_group.cc @@ -17,17 +17,6 @@ namespace xgboost::collective { namespace { -auto MakeConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { - Json config{Object{}}; - config["dmlc_communicator"] = std::string{"rabit"}; - config["DMLC_TRACKER_URI"] = host; - config["DMLC_TRACKER_PORT"] = port; - config["dmlc_timeout_sec"] = static_cast(timeout.count()); - config["DMLC_TASK_ID"] = std::to_string(r); - config["dmlc_retry"] = 2; - return config; -} - class CommGroupTest : public SocketTest {}; } // namespace @@ -36,7 +25,7 @@ TEST_F(CommGroupTest, Basic) { TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { Context ctx; - auto config = MakeConfig(host, port, timeout, r); + auto config = MakeDistributedTestConfig(host, port, timeout, r); std::unique_ptr ptr{CommGroup::Create(config)}; ASSERT_TRUE(ptr->IsDistributed()); ASSERT_EQ(ptr->World(), n_workers); @@ -52,7 +41,7 @@ TEST_F(CommGroupTest, BasicGPU) { TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { auto ctx = MakeCUDACtx(r); - auto config = MakeConfig(host, port, timeout, r); + auto config = MakeDistributedTestConfig(host, port, timeout, r); std::unique_ptr ptr{CommGroup::Create(config)}; auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CUDA(0)); ASSERT_EQ(comm.TaskID(), std::to_string(r)); diff --git a/tests/cpp/collective/test_loop.cc b/tests/cpp/collective/test_loop.cc index 0908d9623..34e0c1de8 100644 --- a/tests/cpp/collective/test_loop.cc +++ b/tests/cpp/collective/test_loop.cc @@ -28,13 +28,11 @@ class LoopTest : public ::testing::Test { auto domain = SockDomain::kV4; pair_.first = TCPSocket::Create(domain); - in_port_t port{0}; + std::int32_t port{0}; auto rc = Success() << [&] { - port = pair_.first.BindHost(); - return Success(); + return pair_.first.BindHost(&port); } << [&] { - pair_.first.Listen(); - return Success(); + return pair_.first.Listen(); }; SafeColl(rc); diff --git a/tests/cpp/collective/test_socket.cc b/tests/cpp/collective/test_socket.cc index ced795fef..ea57da9b4 100644 --- a/tests/cpp/collective/test_socket.cc +++ b/tests/cpp/collective/test_socket.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023, XGBoost Contributors + * Copyright 2022-2024, XGBoost Contributors */ #include #include @@ -21,14 +21,19 @@ TEST_F(SocketTest, Basic) { auto run_test = [msg](SockDomain domain) { auto server = TCPSocket::Create(domain); ASSERT_EQ(server.Domain(), domain); - auto port = server.BindHost(); - server.Listen(); + std::int32_t port{0}; + auto rc = Success() << [&] { + return server.BindHost(&port); + } << [&] { + return server.Listen(); + }; + SafeColl(rc); TCPSocket client; if (domain == SockDomain::kV4) { auto const& addr = SockAddrV4::Loopback().Addr(); auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } else { auto const& addr = SockAddrV6::Loopback().Addr(); auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client); @@ -45,7 +50,8 @@ TEST_F(SocketTest, Basic) { accepted.Send(msg); std::string str; - client.Recv(&str); + rc = client.Recv(&str); + SafeColl(rc); ASSERT_EQ(StringView{str}, msg); }; diff --git a/tests/cpp/collective/test_tracker.cc b/tests/cpp/collective/test_tracker.cc index 0dce33c0c..8d6cbeff2 100644 --- a/tests/cpp/collective/test_tracker.cc +++ b/tests/cpp/collective/test_tracker.cc @@ -1,6 +1,7 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ +#include #include #include // for seconds @@ -10,6 +11,7 @@ #include // for vector #include "../../../src/collective/comm.h" +#include "../helpers.h" // for GMockThrow #include "test_worker.h" namespace xgboost::collective { @@ -20,13 +22,13 @@ class PrintWorker : public WorkerForTest { void Print() { auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank())); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } }; } // namespace TEST_F(TrackerTest, Bootstrap) { - RabitTracker tracker{host, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; ASSERT_FALSE(tracker.Ready()); auto fut = tracker.Run(); @@ -34,7 +36,7 @@ TEST_F(TrackerTest, Bootstrap) { auto args = tracker.WorkerArgs(); ASSERT_TRUE(tracker.Ready()); - ASSERT_EQ(get(args["DMLC_TRACKER_URI"]), host); + ASSERT_EQ(get(args["dmlc_tracker_uri"]), host); std::int32_t port = tracker.Port(); @@ -44,12 +46,11 @@ TEST_F(TrackerTest, Bootstrap) { for (auto &w : workers) { w.join(); } - - ASSERT_TRUE(fut.get().OK()); + SafeColl(fut.get()); } TEST_F(TrackerTest, Print) { - RabitTracker tracker{host, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; auto fut = tracker.Run(); std::vector workers; @@ -73,4 +74,47 @@ TEST_F(TrackerTest, Print) { } TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); } + +/** + * Test connecting the tracker after it has finished. This should not hang the workers. + */ +TEST_F(TrackerTest, AfterShutdown) { + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; + auto fut = tracker.Run(); + + std::vector workers; + auto rc = tracker.WaitUntilReady(); + ASSERT_TRUE(rc.OK()); + + std::int32_t port = tracker.Port(); + + // Launch no-op workers to cause the tracker to shutdown. + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; }); + } + + for (auto &w : workers) { + w.join(); + } + + ASSERT_TRUE(fut.get().OK()); + + // Launch workers again, they should fail. + workers.clear(); + for (std::int32_t i = 0; i < n_workers; ++i) { + auto assert_that = [=] { + WorkerForTest worker{host, port, timeout, n_workers, i}; + }; + // On a Linux platform, the connection will be refused, on Apple platform, this gets + // an operation now in progress poll failure, on Windows, it's a timeout error. +#if defined(__linux__) + workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Connection refused")); }); +#else + workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Failed to connect to")); }); +#endif + } + for (auto &w : workers) { + w.join(); + } +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index 7b76052c8..c84df528f 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -37,7 +37,7 @@ class WorkerForTest { comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName()} { CHECK_EQ(world_size_, comm_.World()); } - virtual ~WorkerForTest() = default; + virtual ~WorkerForTest() noexcept(false) { SafeColl(comm_.Shutdown()); } auto& Comm() { return comm_; } void LimitSockBuf(std::int32_t n_bytes) { @@ -87,19 +87,30 @@ class TrackerTest : public SocketTest { void SetUp() override { SocketTest::SetUp(); auto rc = GetHostAddress(&host); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } }; +inline Json MakeTrackerConfig(std::string host, std::int32_t n_workers, + std::chrono::seconds timeout) { + Json config{Object{}}; + config["host"] = host; + config["port"] = Integer{0}; + config["n_workers"] = Integer{n_workers}; + config["sortby"] = Integer{static_cast(Tracker::SortBy::kHost)}; + config["timeout"] = timeout.count(); + return config; +} + template void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { std::chrono::seconds timeout{2}; std::string host; auto rc = GetHostAddress(&host); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); LOG(INFO) << "Using " << n_workers << " workers for test."; - RabitTracker tracker{StringView{host}, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; auto fut = tracker.Run(); std::vector workers; @@ -115,4 +126,15 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { ASSERT_TRUE(fut.get().OK()); } +inline auto MakeDistributedTestConfig(std::string host, std::int32_t port, + std::chrono::seconds timeout, std::int32_t r) { + Json config{Object{}}; + config["dmlc_communicator"] = std::string{"rabit"}; + config["dmlc_tracker_uri"] = host; + config["dmlc_tracker_port"] = port; + config["dmlc_timeout_sec"] = static_cast(timeout.count()); + config["dmlc_task_id"] = std::to_string(r); + config["dmlc_retry"] = 2; + return config; +} } // namespace xgboost::collective diff --git a/tests/cpp/plugin/federated/test_federated_tracker.cc b/tests/cpp/plugin/federated/test_federated_tracker.cc index 81ff95540..aa979ff15 100644 --- a/tests/cpp/plugin/federated/test_federated_tracker.cc +++ b/tests/cpp/plugin/federated/test_federated_tracker.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include @@ -8,7 +8,6 @@ #include "../../../../src/collective/tracker.h" // for GetHostAddress #include "federated_tracker.h" -#include "test_worker.h" #include "xgboost/json.h" // for Json namespace xgboost::collective { @@ -26,7 +25,7 @@ TEST(FederatedTrackerTest, Basic) { ASSERT_GE(tracker->Port(), 1); std::string host; auto rc = GetHostAddress(&host); - ASSERT_EQ(get(args["DMLC_TRACKER_URI"]), host); + ASSERT_EQ(get(args["dmlc_tracker_uri"]), host); rc = tracker->Shutdown(); ASSERT_TRUE(rc.OK());