[rabit] Improved connection handling. (#9531)

- Enable timeout.
- Report connection error from the system.
- Handle retry for both tracker connection and peer connection.
This commit is contained in:
Jiaming Yuan
2023-08-30 13:00:04 +08:00
committed by GitHub
parent 2462e22cd4
commit ccfc90e4c6
10 changed files with 463 additions and 130 deletions

View File

@@ -1,19 +1,22 @@
/*!
* Copyright (c) 2022 by XGBoost Contributors
/**
* Copyright 2022-2023 by XGBoost Contributors
*/
#include "xgboost/collective/socket.h"
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t
#include <cstring> // std::memcpy, std::memset
#include <filesystem> // for path
#include <system_error> // std::error_code, std::system_category
#include "rabit/internal/socket.h" // for PollHelper
#include "xgboost/collective/result.h" // for Result
#if defined(__unix__) || defined(__APPLE__)
#include <netdb.h> // getaddrinfo, freeaddrinfo
#endif // defined(__unix__) || defined(__APPLE__)
namespace xgboost {
namespace collective {
namespace xgboost::collective {
SockAddress MakeSockAddress(StringView host, in_port_t port) {
struct addrinfo hints;
std::memset(&hints, 0, sizeof(hints));
@@ -71,7 +74,12 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
return bytes;
}
std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
std::chrono::seconds timeout,
xgboost::collective::TCPSocket *out_conn) {
auto addr = MakeSockAddress(xgboost::StringView{host}, port);
auto &conn = *out_conn;
sockaddr const *addr_handle{nullptr};
socklen_t addr_len{0};
if (addr.IsV4()) {
@@ -81,14 +89,67 @@ std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
addr_handle = reinterpret_cast<const sockaddr *>(&addr.V6().Handle());
addr_len = sizeof(addr.V6().Handle());
}
auto socket = TCPSocket::Create(addr.Domain());
CHECK_EQ(static_cast<std::int32_t>(socket.Domain()), static_cast<std::int32_t>(addr.Domain()));
auto rc = connect(socket.Handle(), addr_handle, addr_len);
if (rc != 0) {
return std::error_code{errno, std::system_category()};
conn = TCPSocket::Create(addr.Domain());
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
conn.SetNonBlock(true);
Result last_error;
auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) {
last_error = std::move(err);
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
<< "): Failed to connect to:" << host << " Error:" << last_error.Report();
};
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
if (attempt > 0) {
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
#if defined(_MSC_VER) || defined(__MINGW32__)
Sleep(attempt << 1);
#else
sleep(attempt << 1);
#endif
}
auto rc = connect(conn.Handle(), addr_handle, addr_len);
if (rc != 0) {
auto errcode = system::LastError();
if (!system::ErrorWouldBlock(errcode)) {
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
__FILE__, __LINE__);
continue;
}
rabit::utils::PollHelper poll;
poll.WatchWrite(conn);
auto result = poll.Poll(timeout);
if (!result.OK()) {
log_failure(std::move(result), __FILE__, __LINE__);
continue;
}
if (!poll.CheckWrite(conn)) {
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}),
__FILE__, __LINE__);
continue;
}
result = conn.GetSockError();
if (!result.OK()) {
log_failure(std::move(result), __FILE__, __LINE__);
continue;
}
conn.SetNonBlock(false);
return Success();
} else {
conn.SetNonBlock(false);
return Success();
}
}
*out = std::move(socket);
return std::make_error_code(std::errc{});
std::stringstream ss;
ss << "Failed to connect to " << host << ":" << port;
conn.Close();
return Fail(ss.str(), std::move(last_error));
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective