[rabit] Improved connection handling. (#9531)

- Enable timeout.
- Report connection error from the system.
- Handle retry for both tracker connection and peer connection.
This commit is contained in:
Jiaming Yuan
2023-08-30 13:00:04 +08:00
committed by GitHub
parent 2462e22cd4
commit ccfc90e4c6
10 changed files with 463 additions and 130 deletions

View File

@@ -0,0 +1,160 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <memory> // for unique_ptr
#include <sstream> // for stringstream
#include <stack> // for stack
#include <string> // for string
#include <system_error> // for error_code
#include <utility> // for move
namespace xgboost::collective {
namespace detail {
/**
 * @brief A single link in a chain of errors.
 *
 * Every link carries a human readable message, an optional system error code and an
 * owning pointer to the error that was recorded before it (`prev`). The newest error
 * is the head of the chain; following `prev` walks towards the oldest one.
 *
 * The type is move-only and must be constructed with at least a message.
 */
struct ResultImpl {
  std::string message;
  std::error_code errc{};  // optional for system error.
  std::unique_ptr<ResultImpl> prev{nullptr};

  ResultImpl() = delete;  // must initialize.
  ResultImpl(ResultImpl const& that) = delete;
  ResultImpl(ResultImpl&& that) = default;
  ResultImpl& operator=(ResultImpl const& that) = delete;
  ResultImpl& operator=(ResultImpl&& that) = default;

  explicit ResultImpl(std::string msg) : message{std::move(msg)} {}
  explicit ResultImpl(std::string msg, std::error_code errc)
      : message{std::move(msg)}, errc{errc} {}
  explicit ResultImpl(std::string msg, std::unique_ptr<ResultImpl> prev)
      : message{std::move(msg)}, prev{std::move(prev)} {}
  explicit ResultImpl(std::string msg, std::error_code errc, std::unique_ptr<ResultImpl> prev)
      : message{std::move(msg)}, errc{errc}, prev{std::move(prev)} {}

  /**
   * @brief Two chains are equal when they have the same length and every pair of
   *        corresponding links has the same message and error code.
   */
  [[nodiscard]] bool operator==(ResultImpl const& that) const noexcept(true) {
    // Walk both chains in lock step instead of recursing through `prev`.
    auto lhs = this;
    auto rhs = &that;
    while (lhs && rhs) {
      if (lhs->message != rhs->message || lhs->errc != rhs->errc) {
        return false;
      }
      lhs = lhs->prev.get();
      rhs = rhs->prev.get();
    }
    // Equal only if both chains ended at the same time.
    return lhs == nullptr && rhs == nullptr;
  }

  /**
   * @brief Format the whole chain, newest error first, one "\n- " prefixed entry per link.
   *
   * The description of a link's error code is appended only when the code is set.
   */
  [[nodiscard]] std::string Report() const {
    std::stringstream ss;
    ss << "\n- " << this->message;
    if (this->errc != std::error_code{}) {
      ss << " system error:" << this->errc.message();
    }

    for (auto ptr = prev.get(); ptr != nullptr; ptr = ptr->prev.get()) {
      ss << "\n- " << ptr->message;
      if (ptr->errc != std::error_code{}) {
        ss << " " << ptr->errc.message();
      }
    }
    return ss.str();
  }

  /**
   * @brief Find the root error code.
   *
   * @return The error code of the oldest link (the one furthest along `prev`) that has
   *         a code set, or a default constructed `std::error_code` if none has.
   */
  [[nodiscard]] auto Code() const {
    std::error_code root{};
    // A later hit is always deeper in the chain, so the last one seen is the root.
    for (auto ptr = this; ptr != nullptr; ptr = ptr->prev.get()) {
      if (ptr->errc != std::error_code{}) {
        root = ptr->errc;
      }
    }
    return root;
  }
};
} // namespace detail
/**
* @brief An error type that's easier to handle than throwing dmlc exception. We can
* record and propagate the system error code.
*/
struct Result {
private:
std::unique_ptr<detail::ResultImpl> impl_{nullptr};
public:
Result() noexcept(true) = default;
explicit Result(std::string msg) : impl_{std::make_unique<detail::ResultImpl>(std::move(msg))} {}
explicit Result(std::string msg, std::error_code errc)
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(errc))} {}
Result(std::string msg, Result&& prev)
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(prev.impl_))} {}
Result(std::string msg, std::error_code errc, Result&& prev)
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(errc),
std::move(prev.impl_))} {}
Result(Result const& that) = delete;
Result& operator=(Result const& that) = delete;
Result(Result&& that) = default;
Result& operator=(Result&& that) = default;
[[nodiscard]] bool OK() const noexcept(true) { return !impl_; }
[[nodiscard]] std::string Report() const { return OK() ? "" : impl_->Report(); }
/**
* @brief Return the root system error. This might return success if there's no system error.
*/
[[nodiscard]] auto Code() const { return OK() ? std::error_code{} : impl_->Code(); }
[[nodiscard]] bool operator==(Result const& that) const noexcept(true) {
if (OK() && that.OK()) {
return true;
}
if ((OK() && !that.OK()) || (!OK() && that.OK())) {
return false;
}
return *impl_ == *that.impl_;
}
};
/**
* @brief Return success.
*/
[[nodiscard]] inline auto Success() noexcept(true) { return Result{}; }
/**
* @brief Return failure.
*/
[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
/**
* @brief Return failure with `errno`.
*/
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
return Result{std::move(msg), std::move(errc)};
}
/**
* @brief Return failure with a previous error.
*/
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
return Result{std::move(msg), std::forward<Result>(prev)};
}
/**
* @brief Return failure with a previous error and a new `errno`.
*/
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
}
} // namespace xgboost::collective

View File

@@ -56,9 +56,10 @@ using ssize_t = int;
#endif // defined(_WIN32)
#include "xgboost/base.h" // XGBOOST_EXPECT
#include "xgboost/logging.h" // LOG
#include "xgboost/string_view.h" // StringView
#include "xgboost/base.h" // XGBOOST_EXPECT
#include "xgboost/collective/result.h" // for Result
#include "xgboost/logging.h" // LOG
#include "xgboost/string_view.h" // StringView
#if !defined(HOST_NAME_MAX)
#define HOST_NAME_MAX 256 // macos
@@ -81,6 +82,10 @@ inline std::int32_t LastError() {
#endif
}
/**
 * @brief Build a failed result from `msg` and the value currently returned by LastError(),
 *        interpreted in the system error category.
 */
[[nodiscard]] inline collective::Result FailWithCode(std::string msg) {
  std::error_code const last{LastError(), std::system_category()};
  return collective::Fail(std::move(msg), last);
}
#if defined(__GLIBC__)
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
std::int32_t line = __builtin_LINE(),
@@ -120,15 +125,19 @@ inline std::int32_t CloseSocket(SocketT fd) {
#endif
}
inline bool LastErrorWouldBlock() {
int errsv = LastError();
/**
 * @brief Whether an error value means "the operation could not complete immediately on a
 *        non-blocking socket" rather than a real failure.
 *
 * @param errsv Error value to classify.
 *
 * @return true for WSAEWOULDBLOCK on Windows; true for EAGAIN, EWOULDBLOCK or
 *         EINPROGRESS elsewhere.
 */
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifdef _WIN32
  return errsv == WSAEWOULDBLOCK;
#else
  // A single return: an earlier `return` without EINPROGRESS made this check unreachable.
  return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS;
#endif  // _WIN32
}
/** \brief Classify the value returned by LastError() with ErrorWouldBlock(). */
inline bool LastErrorWouldBlock() {
  return ErrorWouldBlock(LastError());
}
inline void SocketStartup() {
#if defined(_WIN32)
WSADATA wsa_data;
@@ -315,23 +324,35 @@ class TCPSocket {
/** \brief Whether the stored handle equals the invalid-socket sentinel. */
bool IsClosed() const {
  auto const invalid = InvalidSocket();
  return handle_ == invalid;
}
/** \brief get last error code if any */
std::int32_t GetSockError() const {
std::int32_t error = 0;
socklen_t len = sizeof(error);
xgboost_CHECK_SYS_CALL(
getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&error), &len), 0);
return error;
Result GetSockError() const {
std::int32_t optval = 0;
socklen_t len = sizeof(optval);
auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&optval), &len);
if (ret != 0) {
auto errc = std::error_code{system::LastError(), std::system_category()};
return Fail("Failed to retrieve socket error.", std::move(errc));
}
if (optval != 0) {
auto errc = std::error_code{optval, std::system_category()};
return Fail("Socket error.", std::move(errc));
}
return Success();
}
/**
 * \brief check if anything bad happens
 *
 * \return true if the socket is closed, or if the root error code obtained from
 *         GetSockError() is EBADF or EINTR in the system category; false otherwise.
 */
bool BadSocket() const {
  if (IsClosed()) {
    return true;
  }
  // Query the socket once and compare the resulting code against both values.
  auto const code = GetSockError().Code();
  return code == std::error_code{EBADF, std::system_category()} ||  // NOLINT
         code == std::error_code{EINTR, std::system_category()};    // NOLINT
}
void SetNonBlock() {
bool non_block{true};
void SetNonBlock(bool non_block) {
#if defined(_WIN32)
u_long mode = non_block ? 1 : 0;
xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
@@ -530,10 +551,20 @@ class TCPSocket {
};
/**
* \brief Connect to remote address, returns the error code if failed (no exception is
* raised so that we can retry).
* @brief Connect to remote address, returns the error code if failed.
*
* @param host Host IP address.
* @param port Connection port.
* @param retry Number of retries to attempt.
* @param timeout Timeout of each connection attempt.
* @param out_conn Output socket if the connection is successful. Value is invalid and undefined if
* the connection failed.
*
* @return Connection status.
*/
std::error_code Connect(SockAddress const &addr, TCPSocket *out);
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
std::chrono::seconds timeout,
xgboost::collective::TCPSocket *out_conn);
/**
* \brief Get the local host name.