197 lines
6.3 KiB
C++
197 lines
6.3 KiB
C++
/**
|
|
* Copyright 2022-2024, XGBoost Contributors
|
|
*/
|
|
#include "xgboost/collective/socket.h"

#include <array>         // for array
#include <cstddef>       // std::size_t
#include <cstdint>       // std::int32_t
#include <cstring>       // std::memcpy, std::memset
#include <filesystem>    // for path
#include <limits>        // for numeric_limits
#include <string>        // for string, to_string
#include <system_error>  // for error_code, system_category
#include <thread>        // for sleep_for

#include "xgboost/collective/poll_utils.h"  // for PollHelper
#include "xgboost/collective/result.h"      // for Result

#if defined(__unix__) || defined(__APPLE__)
#include <netdb.h>  // getaddrinfo, freeaddrinfo
#endif              // defined(__unix__) || defined(__APPLE__)
|
|
|
|
namespace xgboost::collective {
|
|
/**
 * @brief Resolve `host` with getaddrinfo and build a socket address carrying `port`.
 *
 * Only the first entry of the list returned by getaddrinfo is inspected.
 *
 * @param host Host name or numeric address handed to getaddrinfo.
 * @param port Port in host byte order; converted with htons before being stored.
 *
 * @return The resolved IPv4 or IPv6 address. A default-constructed SockAddress is returned
 *         when getaddrinfo fails. An address family other than V4/V6 is reported through
 *         LOG(FATAL).
 */
SockAddress MakeSockAddress(StringView host, in_port_t port) {
  struct addrinfo hints;
  std::memset(&hints, 0, sizeof(hints));
  // SOCK_STREAM is a socket type, not a protocol number, so it belongs in ai_socktype.
  hints.ai_socktype = SOCK_STREAM;

  struct addrinfo *res = nullptr;
  int sig = getaddrinfo(host.c_str(), nullptr, &hints, &res);
  if (sig != 0 || res == nullptr) {
    return {};
  }

  auto const family = res->ai_family;
  auto const addr_len = static_cast<std::size_t>(res->ai_addrlen);

  if (family == static_cast<std::int32_t>(SockDomain::kV4)) {
    sockaddr_in addr;
    std::memset(&addr, 0, sizeof(addr));
    // Never copy more bytes than the destination structure can hold.
    std::memcpy(&addr, res->ai_addr, addr_len < sizeof(addr) ? addr_len : sizeof(addr));
    // The list is no longer needed once the address has been copied out.
    freeaddrinfo(res);
    addr.sin_port = htons(port);
    auto v = SockAddrV4{addr};
    return SockAddress{v};
  }

  if (family == static_cast<std::int32_t>(SockDomain::kV6)) {
    sockaddr_in6 addr;
    std::memset(&addr, 0, sizeof(addr));
    std::memcpy(&addr, res->ai_addr, addr_len < sizeof(addr) ? addr_len : sizeof(addr));
    freeaddrinfo(res);
    addr.sin6_port = htons(port);
    auto v = SockAddrV6{addr};
    return SockAddress{v};
  }

  // Release the list before reporting so that it is not leaked on the error path.
  freeaddrinfo(res);
  LOG(FATAL) << "Failed to get addr info for: " << host;

  return SockAddress{};
}
|
|
|
|
/** @brief IPv4 address obtained by resolving "127.0.0.1" with port 0. */
SockAddrV4 SockAddrV4::Loopback() {
  auto resolved = MakeSockAddress("127.0.0.1", 0);
  return resolved.V4();
}
|
|
/** @brief IPv4 address obtained by resolving "0.0.0.0" with port 0. */
SockAddrV4 SockAddrV4::InaddrAny() {
  auto resolved = MakeSockAddress("0.0.0.0", 0);
  return resolved.V4();
}
|
|
|
|
/** @brief IPv6 address obtained by resolving "::1" with port 0. */
SockAddrV6 SockAddrV6::Loopback() {
  auto resolved = MakeSockAddress("::1", 0);
  return resolved.V6();
}
|
|
/** @brief IPv6 address obtained by resolving "::" with port 0. */
SockAddrV6 SockAddrV6::InaddrAny() {
  auto resolved = MakeSockAddress("::", 0);
  return resolved.V6();
}
|
|
|
|
/**
 * @brief Send a length-prefixed string over this socket.
 *
 * The size of the string is written first as a std::int32_t, followed by the characters
 * themselves. The outcome of the whole sequence is passed to SafeColl.
 *
 * @param str The string to send; its size must be smaller than the int32 maximum.
 *
 * @return The byte count reported by the last SendAll call.
 */
std::size_t TCPSocket::Send(StringView str) {
  CHECK(!this->IsClosed());
  CHECK_LT(str.size(), std::numeric_limits<std::int32_t>::max());

  auto header = static_cast<std::int32_t>(str.size());
  std::size_t sent{0};

  auto status = Success() << [&] {
    // Stage 1: the length prefix.
    return this->SendAll(&header, sizeof(header), &sent);
  } << [&] {
    return sent == sizeof(header) ? Success() : Fail("Failed to send string length.");
  } << [&] {
    // Stage 2: the payload; `sent` is reused for the payload byte count.
    return this->SendAll(str.c_str(), str.size(), &sent);
  } << [&] {
    return sent == str.size() ? Success() : Fail("Failed to send string.");
  };

  SafeColl(status);
  return sent;
}
|
|
|
|
/**
 * @brief Receive a length-prefixed string from this socket.
 *
 * Reads a std::int32_t length first, then that many bytes into `p_str`.
 *
 * @param p_str Output string; resized to the received length and filled in place.
 *
 * @return Success when both the length and the payload were fully received, otherwise a
 *         failure describing the stage that went wrong.
 */
[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
  CHECK(!this->IsClosed());
  std::int32_t len{0};
  std::size_t n_bytes{0};
  return Success() << [&] {
    return this->RecvAll(&len, sizeof(len), &n_bytes);
  } << [&] {
    if (n_bytes != sizeof(len)) {
      return Fail("Failed to recv string length.");
    }
    // The length comes from the peer. A negative value would be converted into a huge
    // unsigned size by the resize below, so reject it here.
    if (len < 0) {
      return Fail("Invalid string length received.");
    }
    return Success();
  } << [&] {
    p_str->resize(static_cast<std::size_t>(len));
    return this->RecvAll(p_str->data(), len, &n_bytes);
  } << [&] {
    if (n_bytes != static_cast<std::size_t>(len)) {
      return Fail("Failed to recv string.");
    }
    return Success();
  };
}
|
|
|
|
/**
 * @brief Establish a TCP connection to `host:port`, retrying on failure.
 *
 * The socket is switched to non-blocking mode for the duration of the connection attempts
 * and its previous blocking mode is restored on success. When every attempt fails the
 * socket is closed and the last recorded error is attached to the returned failure.
 *
 * @param host     Host to connect to; resolved through MakeSockAddress.
 * @param port     Destination port, must fit into in_port_t.
 * @param retry    Number of attempts; at least one attempt is always made.
 * @param timeout  Timeout passed to poll while waiting for the socket to become writable.
 * @param out_conn Output socket; overwritten with a newly created socket.
 *
 * @return Success once connected, otherwise a failure result.
 */
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
                             std::chrono::seconds timeout,
                             xgboost::collective::TCPSocket *out_conn) {
  // The port is narrowed to in_port_t below. Reject values outside of its range instead of
  // silently wrapping around to a different port.
  if (port < 0 || port > static_cast<std::int32_t>(std::numeric_limits<in_port_t>::max())) {
    return Fail("Invalid port: " + std::to_string(port));
  }

  auto addr = MakeSockAddress(xgboost::StringView{host}, static_cast<in_port_t>(port));
  auto &conn = *out_conn;

  sockaddr const *addr_handle{nullptr};
  socklen_t addr_len{0};
  if (addr.IsV4()) {
    addr_handle = reinterpret_cast<const sockaddr *>(&addr.V4().Handle());
    addr_len = sizeof(addr.V4().Handle());
  } else {
    addr_handle = reinterpret_cast<const sockaddr *>(&addr.V6().Handle());
    addr_len = sizeof(addr.V6().Handle());
  }

  conn = TCPSocket::Create(addr.Domain());
  CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
  // Remember the current mode so that it can be restored once the connection is made.
  auto non_blocking = conn.NonBlocking();
  auto rc = conn.NonBlocking(true);
  if (!rc.OK()) {
    return Fail("Failed to set socket option.", std::move(rc));
  }

  Result last_error;
  // Keep the most recent failure for the final report and emit a warning for each attempt.
  auto log_failure = [&host, &last_error, port](Result err, char const *file, std::int32_t line) {
    last_error = std::move(err);
    LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
                 << "): Failed to connect to:" << host << ":" << port
                 << " Error:" << last_error.Report();
  };

  // NOTE(review): POSIX leaves the state of a socket unspecified after a failed connect();
  // the retries below reuse the same socket. Consider recreating it per attempt.
  for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
    if (attempt > 0) {
      LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
      // Back off a little longer before each subsequent attempt.
      std::this_thread::sleep_for(std::chrono::seconds{attempt << 1});
    }

    auto connect_rc = connect(conn.Handle(), addr_handle, addr_len);
    if (connect_rc == 0) {
      return conn.NonBlocking(non_blocking);
    }

    auto errcode = system::LastError();
    if (!system::ErrorWouldBlock(errcode)) {
      log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
                  __FILE__, __LINE__);
      continue;
    }

    // The connection is in progress, wait until the socket is writable.
    rabit::utils::PollHelper poll;
    poll.WatchWrite(conn);
    auto result = poll.Poll(timeout);
    if (!result.OK()) {
      // poll would fail if there's a socket error, we log the root cause instead of the
      // poll failure.
      auto sockerr = conn.GetSockError();
      if (!sockerr.OK()) {
        result = std::move(sockerr);
      }
      log_failure(std::move(result), __FILE__, __LINE__);
      continue;
    }
    if (!poll.CheckWrite(conn)) {
      // NOTE(review): errcode still holds the would-block code from connect() at this point.
      log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}), __FILE__,
                  __LINE__);
      continue;
    }
    // Writable does not imply connected, check the pending socket error as well.
    result = conn.GetSockError();
    if (!result.OK()) {
      log_failure(std::move(result), __FILE__, __LINE__);
      continue;
    }

    return conn.NonBlocking(non_blocking);
  }

  std::stringstream ss;
  ss << "Failed to connect to " << host << ":" << port;
  auto close_rc = conn.Close();
  return Fail(ss.str(), std::move(close_rc) + std::move(last_error));
}
|
|
|
|
/**
 * @brief Query the name of the current host.
 *
 * @param p_out Output string, assigned the host name on success.
 *
 * @return Success, or a failure carrying the system error code when gethostname fails.
 */
[[nodiscard]] Result GetHostName(std::string *p_out) {
  // HOST_NAME_MAX does not include the terminating null character, reserve one extra byte
  // and zero the buffer so that the result is always null-terminated.
  std::array<char, HOST_NAME_MAX + 1> buf{};
  if (gethostname(buf.data(), HOST_NAME_MAX + 1) != 0) {
    return system::FailWithCode("Failed to get host name.");
  }
  // Termination is not guaranteed by gethostname when the name is truncated.
  buf.back() = '\0';
  *p_out = buf.data();
  return Success();
}
|
|
} // namespace xgboost::collective
|