[coll] Implement shutdown for tracker and comm. (#10208)

- Force shutdown the tracker.
- Implement shutdown notice for error handling thread in comm.
This commit is contained in:
Jiaming Yuan
2024-04-20 04:08:17 +08:00
committed by GitHub
parent 8fb05c8c95
commit 3fbb221fec
24 changed files with 553 additions and 199 deletions

View File

@@ -1555,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
/**
* @brief Get the arguments needed for running workers. This should be called after
* XGTrackerRun() and XGTrackerWait()
* XGTrackerRun().
*
* @param handle The handle to the tracker.
* @param args The arguments returned as a JSON document.
@@ -1565,16 +1565,19 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);
/**
* @brief Run the tracker.
* @brief Start the tracker. The tracker runs in the background and this function returns
* once the tracker is started.
*
* @param handle The handle to the tracker.
* @param config Unused at the moment, preserved for the future.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerRun(TrackerHandle handle);
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config);
/**
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
* @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This
* function will block until the tracker task is finished or timeout is reached.
*
* @param handle The handle to the tracker.
* @param config JSON encoded configuration. No argument is required yet, preserved for
@@ -1582,11 +1585,12 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle);
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
/**
* @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
* cannot close properly, manual interruption is required.
* @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the
* tracker is not properly waited, this function will shutdown all connections with
* the tracker, potentially leading to undefined behavior.
*
* @param handle The handle to the tracker.
*

View File

@@ -124,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) {
#endif
}
inline std::int32_t ShutdownSocket(SocketT fd) {
#if defined(_WIN32)
auto rc = shutdown(fd, SD_BOTH);
if (rc != 0 && LastError() == WSANOTINITIALISED) {
return 0;
}
#else
auto rc = shutdown(fd, SHUT_RDWR);
if (rc != 0 && LastError() == ENOTCONN) {
return 0;
}
#endif
return rc;
}
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
@@ -499,36 +514,49 @@ class TCPSocket {
*/
[[nodiscard]] HandleT const &Handle() const { return handle_; }
/**
* \brief Listen to incoming requests. Should be called after bind.
* @brief Listen to incoming requests. Should be called after bind.
*/
void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
[[nodiscard]] Result Listen(std::int32_t backlog = 16) {
if (listen(handle_, backlog) != 0) {
return system::FailWithCode("Failed to listen.");
}
return Success();
}
/**
* \brief Bind socket to INADDR_ANY, return the port selected by the OS.
* @brief Bind socket to INADDR_ANY, return the port selected by the OS.
*/
[[nodiscard]] in_port_t BindHost() {
[[nodiscard]] Result BindHost(std::int32_t* p_out) {
// Use int32 instead of in_port_t for consistency. We take port as parameter from
// users using other languages, the port is usually stored and passed around as int.
if (Domain() == SockDomain::kV6) {
auto addr = SockAddrV6::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL(
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
return system::FailWithCode("bind failed.");
}
sockaddr_in6 res_addr;
socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL(
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
return ntohs(res_addr.sin6_port);
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
return system::FailWithCode("getsockname failed.");
}
*p_out = ntohs(res_addr.sin6_port);
} else {
auto addr = SockAddrV4::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL(
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
return system::FailWithCode("bind failed.");
}
sockaddr_in res_addr;
socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL(
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
return ntohs(res_addr.sin_port);
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
return system::FailWithCode("getsockname failed.");
}
*p_out = ntohs(res_addr.sin_port);
}
return Success();
}
[[nodiscard]] auto Port() const {
@@ -641,13 +669,13 @@ class TCPSocket {
*/
std::size_t Send(StringView str);
/**
* \brief Receive string, format is matched with the Python socket wrapper in RABIT.
* @brief Receive string, format is matched with the Python socket wrapper in RABIT.
*/
std::size_t Recv(std::string *p_str);
[[nodiscard]] Result Recv(std::string *p_str);
/**
* @brief Close the socket, called automatically in destructor if the socket is not closed.
*/
Result Close() {
[[nodiscard]] Result Close() {
if (InvalidSocket() != handle_) {
auto rc = system::CloseSocket(handle_);
#if defined(_WIN32)
@@ -664,6 +692,25 @@ class TCPSocket {
}
return Success();
}
/**
* @brief Call shutdown on the socket.
*/
[[nodiscard]] Result Shutdown() {
if (this->IsClosed()) {
return Success();
}
auto rc = system::ShutdownSocket(this->Handle());
#if defined(_WIN32)
// Windows cannot shutdown a socket if it's not connected.
if (rc == -1 && system::LastError() == WSAENOTCONN) {
return Success();
}
#endif
if (rc != 0) {
return system::FailWithCode("Failed to shutdown socket.");
}
return Success();
}
/**
* \brief Create a TCP socket on specified domain.