[rabit] Improved connection handling. (#9531)
- Enable timeout.
- Report connection error from the system.
- Handle retry for both tracker connection and peer connection.
This commit is contained in:
@@ -1,19 +1,22 @@
|
||||
/*!
|
||||
* Copyright (c) 2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include "xgboost/collective/socket.h"
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <cstring> // std::memcpy, std::memset
|
||||
#include <filesystem> // for path
|
||||
#include <system_error> // std::error_code, std::system_category
|
||||
|
||||
#include "rabit/internal/socket.h" // for PollHelper
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
#if defined(__unix__) || defined(__APPLE__)
|
||||
#include <netdb.h> // getaddrinfo, freeaddrinfo
|
||||
#endif // defined(__unix__) || defined(__APPLE__)
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
namespace xgboost::collective {
|
||||
SockAddress MakeSockAddress(StringView host, in_port_t port) {
|
||||
struct addrinfo hints;
|
||||
std::memset(&hints, 0, sizeof(hints));
|
||||
@@ -71,7 +74,12 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
return bytes;
|
||||
}
|
||||
|
||||
std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
|
||||
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
|
||||
std::chrono::seconds timeout,
|
||||
xgboost::collective::TCPSocket *out_conn) {
|
||||
auto addr = MakeSockAddress(xgboost::StringView{host}, port);
|
||||
auto &conn = *out_conn;
|
||||
|
||||
sockaddr const *addr_handle{nullptr};
|
||||
socklen_t addr_len{0};
|
||||
if (addr.IsV4()) {
|
||||
@@ -81,14 +89,67 @@ std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
|
||||
addr_handle = reinterpret_cast<const sockaddr *>(&addr.V6().Handle());
|
||||
addr_len = sizeof(addr.V6().Handle());
|
||||
}
|
||||
auto socket = TCPSocket::Create(addr.Domain());
|
||||
CHECK_EQ(static_cast<std::int32_t>(socket.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
||||
auto rc = connect(socket.Handle(), addr_handle, addr_len);
|
||||
if (rc != 0) {
|
||||
return std::error_code{errno, std::system_category()};
|
||||
|
||||
conn = TCPSocket::Create(addr.Domain());
|
||||
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
||||
conn.SetNonBlock(true);
|
||||
|
||||
Result last_error;
|
||||
auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) {
|
||||
last_error = std::move(err);
|
||||
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
|
||||
<< "): Failed to connect to:" << host << " Error:" << last_error.Report();
|
||||
};
|
||||
|
||||
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
|
||||
if (attempt > 0) {
|
||||
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
Sleep(attempt << 1);
|
||||
#else
|
||||
sleep(attempt << 1);
|
||||
#endif
|
||||
}
|
||||
|
||||
auto rc = connect(conn.Handle(), addr_handle, addr_len);
|
||||
if (rc != 0) {
|
||||
auto errcode = system::LastError();
|
||||
if (!system::ErrorWouldBlock(errcode)) {
|
||||
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
|
||||
__FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
rabit::utils::PollHelper poll;
|
||||
poll.WatchWrite(conn);
|
||||
auto result = poll.Poll(timeout);
|
||||
if (!result.OK()) {
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
if (!poll.CheckWrite(conn)) {
|
||||
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}),
|
||||
__FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
result = conn.GetSockError();
|
||||
if (!result.OK()) {
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
conn.SetNonBlock(false);
|
||||
return Success();
|
||||
|
||||
} else {
|
||||
conn.SetNonBlock(false);
|
||||
return Success();
|
||||
}
|
||||
}
|
||||
*out = std::move(socket);
|
||||
return std::make_error_code(std::errc{});
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "Failed to connect to " << host << ":" << port;
|
||||
conn.Close();
|
||||
return Fail(ss.str(), std::move(last_error));
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
Reference in New Issue
Block a user