[coll] Implement shutdown for tracker and comm. (#10208)
- Force shutdown the tracker.
- Implement shutdown notice for error handling thread in comm.
This commit is contained in:
@@ -1555,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
|
||||
|
||||
/**
|
||||
* @brief Get the arguments needed for running workers. This should be called after
|
||||
* XGTrackerRun() and XGTrackerWait()
|
||||
* XGTrackerRun().
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
* @param args The arguments returned as a JSON document.
|
||||
@@ -1565,16 +1565,19 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
|
||||
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);
|
||||
|
||||
/**
|
||||
* @brief Run the tracker.
|
||||
* @brief Start the tracker. The tracker runs in the background and this function returns
|
||||
* once the tracker is started.
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
* @param config Unused at the moment, preserved for the future.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerRun(TrackerHandle handle);
|
||||
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config);
|
||||
|
||||
/**
|
||||
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
|
||||
* @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This
|
||||
* function will block until the tracker task is finished or timeout is reached.
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
* @param config JSON encoded configuration. No argument is required yet, preserved for
|
||||
@@ -1582,11 +1585,12 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle);
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
|
||||
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
|
||||
|
||||
/**
|
||||
* @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
|
||||
* cannot close properly, manual interruption is required.
|
||||
* @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the
|
||||
* tracker has not been properly waited for, this function will shut down all connections with
|
||||
* the tracker, potentially leading to undefined behavior.
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
*
|
||||
|
||||
@@ -124,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) {
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
 * @brief Shut down both the receiving and the sending side of a socket.
 *
 * A failed call is still reported as success when the last error is WSANOTINITIALISED
 * on Windows, or ENOTCONN on other platforms.
 *
 * @param fd The native socket handle.
 *
 * @return 0 on success or when the failure is one of the ignored errors above, otherwise
 *         the non-zero value returned by shutdown().
 */
inline std::int32_t ShutdownSocket(SocketT fd) {
#if defined(_WIN32)
  auto const status = shutdown(fd, SD_BOTH);
  auto const ignorable = WSANOTINITIALISED;
#else
  auto const status = shutdown(fd, SHUT_RDWR);
  auto const ignorable = ENOTCONN;
#endif
  if (status == 0) {
    return status;
  }
  // The error code is only consulted once the call has actually failed.
  return LastError() == ignorable ? 0 : status;
}
|
||||
|
||||
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
|
||||
#ifdef _WIN32
|
||||
return errsv == WSAEWOULDBLOCK;
|
||||
@@ -499,36 +514,49 @@ class TCPSocket {
|
||||
*/
|
||||
[[nodiscard]] HandleT const &Handle() const {
  // Read-only access to the stored native handle.
  return this->handle_;
}
|
||||
/**
|
||||
* \brief Listen to incoming requests. Should be called after bind.
|
||||
* @brief Listen to incoming requests. Should be called after bind.
|
||||
*/
|
||||
void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
|
||||
[[nodiscard]] Result Listen(std::int32_t backlog = 16) {
|
||||
if (listen(handle_, backlog) != 0) {
|
||||
return system::FailWithCode("Failed to listen.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
/**
|
||||
* \brief Bind socket to INADDR_ANY, return the port selected by the OS.
|
||||
* @brief Bind socket to INADDR_ANY, return the port selected by the OS.
|
||||
*/
|
||||
[[nodiscard]] in_port_t BindHost() {
|
||||
[[nodiscard]] Result BindHost(std::int32_t* p_out) {
|
||||
// Use int32 instead of in_port_t for consistency. We take port as parameter from
|
||||
// users using other languages, the port is usually stored and passed around as int.
|
||||
if (Domain() == SockDomain::kV6) {
|
||||
auto addr = SockAddrV6::InaddrAny();
|
||||
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
|
||||
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
|
||||
return system::FailWithCode("bind failed.");
|
||||
}
|
||||
|
||||
sockaddr_in6 res_addr;
|
||||
socklen_t addrlen = sizeof(res_addr);
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
|
||||
return ntohs(res_addr.sin6_port);
|
||||
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
|
||||
return system::FailWithCode("getsockname failed.");
|
||||
}
|
||||
*p_out = ntohs(res_addr.sin6_port);
|
||||
} else {
|
||||
auto addr = SockAddrV4::InaddrAny();
|
||||
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
|
||||
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
|
||||
return system::FailWithCode("bind failed.");
|
||||
}
|
||||
|
||||
sockaddr_in res_addr;
|
||||
socklen_t addrlen = sizeof(res_addr);
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
|
||||
return ntohs(res_addr.sin_port);
|
||||
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
|
||||
return system::FailWithCode("getsockname failed.");
|
||||
}
|
||||
*p_out = ntohs(res_addr.sin_port);
|
||||
}
|
||||
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] auto Port() const {
|
||||
@@ -641,13 +669,13 @@ class TCPSocket {
|
||||
*/
|
||||
std::size_t Send(StringView str);
|
||||
/**
|
||||
* \brief Receive string, format is matched with the Python socket wrapper in RABIT.
|
||||
* @brief Receive string, format is matched with the Python socket wrapper in RABIT.
|
||||
*/
|
||||
std::size_t Recv(std::string *p_str);
|
||||
[[nodiscard]] Result Recv(std::string *p_str);
|
||||
/**
|
||||
* @brief Close the socket, called automatically in destructor if the socket is not closed.
|
||||
*/
|
||||
Result Close() {
|
||||
[[nodiscard]] Result Close() {
|
||||
if (InvalidSocket() != handle_) {
|
||||
auto rc = system::CloseSocket(handle_);
|
||||
#if defined(_WIN32)
|
||||
@@ -664,6 +692,25 @@ class TCPSocket {
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
/**
|
||||
* @brief Call shutdown on the socket.
|
||||
*/
|
||||
[[nodiscard]] Result Shutdown() {
|
||||
if (this->IsClosed()) {
|
||||
return Success();
|
||||
}
|
||||
auto rc = system::ShutdownSocket(this->Handle());
|
||||
#if defined(_WIN32)
|
||||
// Windows cannot shutdown a socket if it's not connected.
|
||||
if (rc == -1 && system::LastError() == WSAENOTCONN) {
|
||||
return Success();
|
||||
}
|
||||
#endif
|
||||
if (rc != 0) {
|
||||
return system::FailWithCode("Failed to shutdown socket.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create a TCP socket on specified domain.
|
||||
|
||||
Reference in New Issue
Block a user