From b14e535e788b4d1c09de886d8a41cd5b316c6bc2 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 10 Oct 2023 10:01:14 +0800 Subject: [PATCH] [Coll] Implement get host address in libxgboost. (#9644) - Port `xgboost.tracker.get_host_ip` in C++. --- include/xgboost/collective/socket.h | 28 ++++++++++ src/collective/tracker.cc | 76 ++++++++++++++++++++++++++++ src/collective/tracker.h | 15 ++++++ tests/cpp/collective/net_test.h | 41 +++++++++++++++ tests/cpp/collective/test_socket.cc | 52 ++++++++----------- tests/cpp/collective/test_tracker.cc | 18 +++++++ 6 files changed, 199 insertions(+), 31 deletions(-) create mode 100644 src/collective/tracker.cc create mode 100644 src/collective/tracker.h create mode 100644 tests/cpp/collective/net_test.h create mode 100644 tests/cpp/collective/test_tracker.cc diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index f36cdccb2..a16dd05c0 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -658,6 +658,34 @@ class TCPSocket { * @brief Get the local host name. 
*/ [[nodiscard]] Result GetHostName(std::string *p_out); + +/** + * @brief inet_ntop + */ +template +Result INetNToP(H const &host, std::string *p_out) { + std::string &ip = *p_out; + switch (host->h_addrtype) { + case AF_INET: { + auto addr = reinterpret_cast(host->h_addr_list[0]); + char str[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN); + ip = str; + break; + } + case AF_INET6: { + auto addr = reinterpret_cast(host->h_addr_list[0]); + char str[INET6_ADDRSTRLEN]; + inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN); + ip = str; + break; + } + default: { + return Fail("Invalid address type."); + } + } + return Success(); +} } // namespace collective } // namespace xgboost diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc new file mode 100644 index 000000000..598b41ddd --- /dev/null +++ b/src/collective/tracker.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#if defined(__unix__) || defined(__APPLE__) +#include // gethostbyname +#include // socket, AF_INET6, AF_INET, connect, getsockname +#endif // defined(__unix__) || defined(__APPLE__) + +#if !defined(NOMINMAX) && defined(_WIN32) +#define NOMINMAX +#endif // !defined(NOMINMAX) + +#if defined(_WIN32) +#include +#include +#endif // defined(_WIN32) + +#include // for string + +#include "xgboost/collective/result.h" // for Result, Fail, Success +#include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ... + +namespace xgboost::collective { +[[nodiscard]] Result GetHostAddress(std::string* out) { + auto rc = GetHostName(out); + if (!rc.OK()) { + return rc; + } + auto host = gethostbyname(out->c_str()); + + // get ip address from host + std::string ip; + rc = INetNToP(host, &ip); + if (!rc.OK()) { + return rc; + } + + if (!(ip.size() >= 4 && ip.substr(0, 4) == "127.")) { + // return if this is a public IP address. 
+ // not entirely accurate, we have other reserved IPs + *out = ip; + return Success(); + } + + // Create a UDP socket to probe the public IP address, it's fine even if it's + // unreachable. + auto sock = socket(AF_INET, SOCK_DGRAM, 0); + if (sock == -1) { + return Fail("Failed to create socket."); + } + + auto paddr = MakeSockAddress(StringView{"10.255.255.255"}, 1); + sockaddr const* addr_handle = reinterpret_cast(&paddr.V4().Handle()); + socklen_t addr_len{sizeof(paddr.V4().Handle())}; + auto err = connect(sock, addr_handle, addr_len); + if (err != 0) { + return system::FailWithCode("Failed to find IP address."); + } + + // get the IP address from socket descriptor + struct sockaddr_in addr; + socklen_t len = sizeof(addr); + if (getsockname(sock, reinterpret_cast(&addr), &len) == -1) { + return Fail("Failed to get sock name."); + } + ip = inet_ntoa(addr.sin_addr); + + err = system::CloseSocket(sock); + if (err != 0) { + return system::FailWithCode("Failed to close socket."); + } + + *out = ip; + return Success(); +} +} // namespace xgboost::collective diff --git a/src/collective/tracker.h b/src/collective/tracker.h new file mode 100644 index 000000000..ec52f6a62 --- /dev/null +++ b/src/collective/tracker.h @@ -0,0 +1,15 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once +#include // for string + +#include "xgboost/collective/result.h" // for Result + +namespace xgboost::collective { +// Probe the public IP address of the host, need a better method. +// +// This is directly translated from the previous Python implementation, we should find a +// more rigorous approach, can use some expertise in network programming. 
+[[nodiscard]] Result GetHostAddress(std::string* out); +} // namespace xgboost::collective diff --git a/tests/cpp/collective/net_test.h b/tests/cpp/collective/net_test.h new file mode 100644 index 000000000..ed15ed256 --- /dev/null +++ b/tests/cpp/collective/net_test.h @@ -0,0 +1,41 @@ +/** + * Copyright 2022-2023, XGBoost Contributors + */ +#pragma once + +#include +#include + +#include // ifstream + +#include "../helpers.h" // for FileExists + +namespace xgboost::collective { +class SocketTest : public ::testing::Test { + protected: + std::string skip_msg_{"Skipping IPv6 test"}; + + bool SkipTest() { + std::string path{"/sys/module/ipv6/parameters/disable"}; + if (FileExists(path)) { + std::ifstream fin(path); + if (!fin) { + return true; + } + std::string s_value; + fin >> s_value; + auto value = std::stoi(s_value); + if (value != 0) { + return true; + } + } else { + return true; + } + return false; + } + + protected: + void SetUp() override { system::SocketStartup(); } + void TearDown() override { system::SocketFinalize(); } +}; +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_socket.cc b/tests/cpp/collective/test_socket.cc index 07a7f52d0..7802acda8 100644 --- a/tests/cpp/collective/test_socket.cc +++ b/tests/cpp/collective/test_socket.cc @@ -1,19 +1,16 @@ /** - * Copyright 2022-2023 by XGBoost Contributors + * Copyright 2022-2023, XGBoost Contributors */ #include #include #include // EADDRNOTAVAIL -#include // ifstream #include // std::error_code, std::system_category -#include "../helpers.h" +#include "net_test.h" // for SocketTest namespace xgboost::collective { -TEST(Socket, Basic) { - system::SocketStartup(); - +TEST_F(SocketTest, Basic) { SockAddress addr{SockAddrV6::Loopback()}; ASSERT_TRUE(addr.IsV6()); addr = SockAddress{SockAddrV4::Loopback()}; @@ -54,34 +51,27 @@ TEST(Socket, Basic) { run_test(SockDomain::kV4); - std::string path{"/sys/module/ipv6/parameters/disable"}; - if (FileExists(path)) { - std::ifstream fin(path); - 
if (!fin) { - GTEST_SKIP_(msg.c_str()); - } - std::string s_value; - fin >> s_value; - auto value = std::stoi(s_value); - if (value != 0) { - GTEST_SKIP_(msg.c_str()); - } - } else { - GTEST_SKIP_(msg.c_str()); + if (SkipTest()) { + GTEST_SKIP_(skip_msg_.c_str()); } run_test(SockDomain::kV6); - - system::SocketFinalize(); } -TEST(Socket, Bind) { - system::SocketStartup(); - auto any = SockAddrV4::InaddrAny().Addr(); - auto sock = TCPSocket::Create(SockDomain::kV4); - std::int32_t port{0}; - auto rc = sock.Bind(any, &port); - ASSERT_TRUE(rc.OK()); - ASSERT_NE(port, 0); - system::SocketFinalize(); +TEST_F(SocketTest, Bind) { + auto run = [](SockDomain domain) { + auto any = + domain == SockDomain::kV4 ? SockAddrV4::InaddrAny().Addr() : SockAddrV6::InaddrAny().Addr(); + auto sock = TCPSocket::Create(domain); + std::int32_t port{0}; + auto rc = sock.Bind(any, &port); + ASSERT_TRUE(rc.OK()); + ASSERT_NE(port, 0); + }; + + run(SockDomain::kV4); + if (SkipTest()) { + GTEST_SKIP_(skip_msg_.c_str()); + } + run(SockDomain::kV6); } } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_tracker.cc b/tests/cpp/collective/test_tracker.cc new file mode 100644 index 000000000..0e60cfb68 --- /dev/null +++ b/tests/cpp/collective/test_tracker.cc @@ -0,0 +1,18 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include "../../../src/collective/tracker.h" // for GetHostAddress +#include "net_test.h" // for SocketTest + +namespace xgboost::collective { +namespace { +class TrackerTest : public SocketTest {}; +} // namespace + +TEST_F(TrackerTest, GetHostAddress) { + std::string host; + auto rc = GetHostAddress(&host); + ASSERT_TRUE(rc.OK()); + ASSERT_TRUE(host.find("127.") == std::string::npos); +} +} // namespace xgboost::collective