[rabit] Improved connection handling. (#9531)
- Enable timeout. - Report connection error from the system. - Handle retry for both tracker connection and peer connection.
This commit is contained in:
160
include/xgboost/collective/result.h
Normal file
160
include/xgboost/collective/result.h
Normal file
@@ -0,0 +1,160 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <memory> // for unique_ptr
|
||||
#include <sstream> // for stringstream
|
||||
#include <stack> // for stack
|
||||
#include <string> // for string
|
||||
#include <utility> // for move
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace detail {
|
||||
struct ResultImpl {
|
||||
std::string message;
|
||||
std::error_code errc{}; // optional for system error.
|
||||
|
||||
std::unique_ptr<ResultImpl> prev{nullptr};
|
||||
|
||||
ResultImpl() = delete; // must initialize.
|
||||
ResultImpl(ResultImpl const& that) = delete;
|
||||
ResultImpl(ResultImpl&& that) = default;
|
||||
ResultImpl& operator=(ResultImpl const& that) = delete;
|
||||
ResultImpl& operator=(ResultImpl&& that) = default;
|
||||
|
||||
explicit ResultImpl(std::string msg) : message{std::move(msg)} {}
|
||||
explicit ResultImpl(std::string msg, std::error_code errc)
|
||||
: message{std::move(msg)}, errc{std::move(errc)} {}
|
||||
explicit ResultImpl(std::string msg, std::unique_ptr<ResultImpl> prev)
|
||||
: message{std::move(msg)}, prev{std::move(prev)} {}
|
||||
explicit ResultImpl(std::string msg, std::error_code errc, std::unique_ptr<ResultImpl> prev)
|
||||
: message{std::move(msg)}, errc{std::move(errc)}, prev{std::move(prev)} {}
|
||||
|
||||
[[nodiscard]] bool operator==(ResultImpl const& that) const noexcept(true) {
|
||||
if ((prev && !that.prev) || (!prev && that.prev)) {
|
||||
// one of them doesn't have prev
|
||||
return false;
|
||||
}
|
||||
|
||||
auto cur_eq = message == that.message && errc == that.errc;
|
||||
if (prev && that.prev) {
|
||||
// recursive comparison
|
||||
auto prev_eq = *prev == *that.prev;
|
||||
return cur_eq && prev_eq;
|
||||
}
|
||||
return cur_eq;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string Report() {
|
||||
std::stringstream ss;
|
||||
ss << "\n- " << this->message;
|
||||
if (this->errc != std::error_code{}) {
|
||||
ss << " system error:" << this->errc.message();
|
||||
}
|
||||
|
||||
auto ptr = prev.get();
|
||||
while (ptr) {
|
||||
ss << "\n- ";
|
||||
ss << ptr->message;
|
||||
|
||||
if (ptr->errc != std::error_code{}) {
|
||||
ss << " " << ptr->errc.message();
|
||||
}
|
||||
ptr = ptr->prev.get();
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
[[nodiscard]] auto Code() const {
|
||||
// Find the root error.
|
||||
std::stack<ResultImpl const*> stack;
|
||||
auto ptr = this;
|
||||
while (ptr) {
|
||||
stack.push(ptr);
|
||||
if (ptr->prev) {
|
||||
ptr = ptr->prev.get();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (!stack.empty()) {
|
||||
auto frame = stack.top();
|
||||
stack.pop();
|
||||
if (frame->errc != std::error_code{}) {
|
||||
return frame->errc;
|
||||
}
|
||||
}
|
||||
return std::error_code{};
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
* @brief An error type that's easier to handle than throwing dmlc exception. We can
|
||||
* record and propagate the system error code.
|
||||
*/
|
||||
struct Result {
|
||||
private:
|
||||
std::unique_ptr<detail::ResultImpl> impl_{nullptr};
|
||||
|
||||
public:
|
||||
Result() noexcept(true) = default;
|
||||
explicit Result(std::string msg) : impl_{std::make_unique<detail::ResultImpl>(std::move(msg))} {}
|
||||
explicit Result(std::string msg, std::error_code errc)
|
||||
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(errc))} {}
|
||||
Result(std::string msg, Result&& prev)
|
||||
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(prev.impl_))} {}
|
||||
Result(std::string msg, std::error_code errc, Result&& prev)
|
||||
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(errc),
|
||||
std::move(prev.impl_))} {}
|
||||
|
||||
Result(Result const& that) = delete;
|
||||
Result& operator=(Result const& that) = delete;
|
||||
Result(Result&& that) = default;
|
||||
Result& operator=(Result&& that) = default;
|
||||
|
||||
[[nodiscard]] bool OK() const noexcept(true) { return !impl_; }
|
||||
[[nodiscard]] std::string Report() const { return OK() ? "" : impl_->Report(); }
|
||||
/**
|
||||
* @brief Return the root system error. This might return success if there's no system error.
|
||||
*/
|
||||
[[nodiscard]] auto Code() const { return OK() ? std::error_code{} : impl_->Code(); }
|
||||
[[nodiscard]] bool operator==(Result const& that) const noexcept(true) {
|
||||
if (OK() && that.OK()) {
|
||||
return true;
|
||||
}
|
||||
if ((OK() && !that.OK()) || (!OK() && that.OK())) {
|
||||
return false;
|
||||
}
|
||||
return *impl_ == *that.impl_;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Return success.
|
||||
*/
|
||||
[[nodiscard]] inline auto Success() noexcept(true) { return Result{}; }
|
||||
/**
|
||||
* @brief Return failure.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
|
||||
/**
|
||||
* @brief Return failure with `errno`.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
|
||||
return Result{std::move(msg), std::move(errc)};
|
||||
}
|
||||
/**
|
||||
* @brief Return failure with a previous error.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
|
||||
return Result{std::move(msg), std::forward<Result>(prev)};
|
||||
}
|
||||
/**
|
||||
* @brief Return failure with a previous error and a new `errno`.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
|
||||
return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -56,9 +56,10 @@ using ssize_t = int;
|
||||
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#include "xgboost/base.h" // XGBOOST_EXPECT
|
||||
#include "xgboost/logging.h" // LOG
|
||||
#include "xgboost/string_view.h" // StringView
|
||||
#include "xgboost/base.h" // XGBOOST_EXPECT
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/logging.h" // LOG
|
||||
#include "xgboost/string_view.h" // StringView
|
||||
|
||||
#if !defined(HOST_NAME_MAX)
|
||||
#define HOST_NAME_MAX 256 // macos
|
||||
@@ -81,6 +82,10 @@ inline std::int32_t LastError() {
|
||||
#endif
|
||||
}
|
||||
|
||||
[[nodiscard]] inline collective::Result FailWithCode(std::string msg) {
|
||||
return collective::Fail(std::move(msg), std::error_code{LastError(), std::system_category()});
|
||||
}
|
||||
|
||||
#if defined(__GLIBC__)
|
||||
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
|
||||
std::int32_t line = __builtin_LINE(),
|
||||
@@ -120,15 +125,19 @@ inline std::int32_t CloseSocket(SocketT fd) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline bool LastErrorWouldBlock() {
|
||||
int errsv = LastError();
|
||||
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
|
||||
#ifdef _WIN32
|
||||
return errsv == WSAEWOULDBLOCK;
|
||||
#else
|
||||
return errsv == EAGAIN || errsv == EWOULDBLOCK;
|
||||
return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS;
|
||||
#endif // _WIN32
|
||||
}
|
||||
|
||||
inline bool LastErrorWouldBlock() {
|
||||
int errsv = LastError();
|
||||
return ErrorWouldBlock(errsv);
|
||||
}
|
||||
|
||||
inline void SocketStartup() {
|
||||
#if defined(_WIN32)
|
||||
WSADATA wsa_data;
|
||||
@@ -315,23 +324,35 @@ class TCPSocket {
|
||||
bool IsClosed() const { return handle_ == InvalidSocket(); }
|
||||
|
||||
/** \brief get last error code if any */
|
||||
std::int32_t GetSockError() const {
|
||||
std::int32_t error = 0;
|
||||
socklen_t len = sizeof(error);
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&error), &len), 0);
|
||||
return error;
|
||||
Result GetSockError() const {
|
||||
std::int32_t optval = 0;
|
||||
socklen_t len = sizeof(optval);
|
||||
auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&optval), &len);
|
||||
if (ret != 0) {
|
||||
auto errc = std::error_code{system::LastError(), std::system_category()};
|
||||
return Fail("Failed to retrieve socket error.", std::move(errc));
|
||||
}
|
||||
if (optval != 0) {
|
||||
auto errc = std::error_code{optval, std::system_category()};
|
||||
return Fail("Socket error.", std::move(errc));
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
/** \brief check if anything bad happens */
|
||||
bool BadSocket() const {
|
||||
if (IsClosed()) return true;
|
||||
std::int32_t err = GetSockError();
|
||||
if (err == EBADF || err == EINTR) return true;
|
||||
if (IsClosed()) {
|
||||
return true;
|
||||
}
|
||||
auto err = GetSockError();
|
||||
if (err.Code() == std::error_code{EBADF, std::system_category()} || // NOLINT
|
||||
err.Code() == std::error_code{EINTR, std::system_category()}) { // NOLINT
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void SetNonBlock() {
|
||||
bool non_block{true};
|
||||
void SetNonBlock(bool non_block) {
|
||||
#if defined(_WIN32)
|
||||
u_long mode = non_block ? 1 : 0;
|
||||
xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
|
||||
@@ -530,10 +551,20 @@ class TCPSocket {
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Connect to remote address, returns the error code if failed (no exception is
|
||||
* raised so that we can retry).
|
||||
* @brief Connect to remote address, returns the error code if failed.
|
||||
*
|
||||
* @param host Host IP address.
|
||||
* @param port Connection port.
|
||||
* @param retry Number of retries to attempt.
|
||||
* @param timeout Timeout of each connection attempt.
|
||||
* @param out_conn Output socket if the connection is successful. Value is invalid and undefined if
|
||||
* the connection failed.
|
||||
*
|
||||
* @return Connection status.
|
||||
*/
|
||||
std::error_code Connect(SockAddress const &addr, TCPSocket *out);
|
||||
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
|
||||
std::chrono::seconds timeout,
|
||||
xgboost::collective::TCPSocket *out_conn);
|
||||
|
||||
/**
|
||||
* \brief Get the local host name.
|
||||
|
||||
Reference in New Issue
Block a user