[Coll] Implement get host address in libxgboost. (#9644)
- Port `xgboost.tracker.get_host_ip` in C++.
This commit is contained in:
parent
680d53db43
commit
b14e535e78
@ -658,6 +658,34 @@ class TCPSocket {
|
|||||||
* @brief Get the local host name.
|
* @brief Get the local host name.
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] Result GetHostName(std::string *p_out);
|
[[nodiscard]] Result GetHostName(std::string *p_out);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief inet_ntop
|
||||||
|
*/
|
||||||
|
template <typename H>
|
||||||
|
Result INetNToP(H const &host, std::string *p_out) {
|
||||||
|
std::string &ip = *p_out;
|
||||||
|
switch (host->h_addrtype) {
|
||||||
|
case AF_INET: {
|
||||||
|
auto addr = reinterpret_cast<struct in_addr *>(host->h_addr_list[0]);
|
||||||
|
char str[INET_ADDRSTRLEN];
|
||||||
|
inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
|
||||||
|
ip = str;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case AF_INET6: {
|
||||||
|
auto addr = reinterpret_cast<struct in6_addr *>(host->h_addr_list[0]);
|
||||||
|
char str[INET6_ADDRSTRLEN];
|
||||||
|
inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
|
||||||
|
ip = str;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
return Fail("Invalid address type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
} // namespace collective
|
} // namespace collective
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
76
src/collective/tracker.cc
Normal file
76
src/collective/tracker.cc
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#if defined(__unix__) || defined(__APPLE__)
|
||||||
|
#include <netdb.h> // gethostbyname
|
||||||
|
#include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname
|
||||||
|
#endif // defined(__unix__) || defined(__APPLE__)
|
||||||
|
|
||||||
|
#if !defined(NOMINMAX) && defined(_WIN32)
|
||||||
|
#define NOMINMAX
|
||||||
|
#endif // !defined(NOMINMAX)
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
#include <winsock2.h>
|
||||||
|
#include <ws2tcpip.h>
|
||||||
|
#endif // defined(_WIN32)
|
||||||
|
|
||||||
|
#include <string> // for string
|
||||||
|
|
||||||
|
#include "xgboost/collective/result.h" // for Result, Fail, Success
|
||||||
|
#include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ...
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
[[nodiscard]] Result GetHostAddress(std::string* out) {
|
||||||
|
auto rc = GetHostName(out);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
auto host = gethostbyname(out->c_str());
|
||||||
|
|
||||||
|
// get ip address from host
|
||||||
|
std::string ip;
|
||||||
|
rc = INetNToP(host, &ip);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!(ip.size() >= 4 && ip.substr(0, 4) == "127.")) {
|
||||||
|
// return if this is a public IP address.
|
||||||
|
// not entirely accurate, we have other reserved IPs
|
||||||
|
*out = ip;
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an UDP socket to prob the public IP address, it's fine even if it's
|
||||||
|
// unreachable.
|
||||||
|
auto sock = socket(AF_INET, SOCK_DGRAM, 0);
|
||||||
|
if (sock == -1) {
|
||||||
|
return Fail("Failed to create socket.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto paddr = MakeSockAddress(StringView{"10.255.255.255"}, 1);
|
||||||
|
sockaddr const* addr_handle = reinterpret_cast<const sockaddr*>(&paddr.V4().Handle());
|
||||||
|
socklen_t addr_len{sizeof(paddr.V4().Handle())};
|
||||||
|
auto err = connect(sock, addr_handle, addr_len);
|
||||||
|
if (err != 0) {
|
||||||
|
return system::FailWithCode("Failed to find IP address.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the IP address from socket desrciptor
|
||||||
|
struct sockaddr_in addr;
|
||||||
|
socklen_t len = sizeof(addr);
|
||||||
|
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &len) == -1) {
|
||||||
|
return Fail("Failed to get sock name.");
|
||||||
|
}
|
||||||
|
ip = inet_ntoa(addr.sin_addr);
|
||||||
|
|
||||||
|
err = system::CloseSocket(sock);
|
||||||
|
if (err != 0) {
|
||||||
|
return system::FailWithCode("Failed to close socket.");
|
||||||
|
}
|
||||||
|
|
||||||
|
*out = ip;
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
15
src/collective/tracker.h
Normal file
15
src/collective/tracker.h
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <string> // for string
|
||||||
|
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
// Prob the public IP address of the host, need a better method.
|
||||||
|
//
|
||||||
|
// This is directly translated from the previous Python implementation, we should find a
|
||||||
|
// more riguous approach, can use some expertise in network programming.
|
||||||
|
[[nodiscard]] Result GetHostAddress(std::string* out);
|
||||||
|
} // namespace xgboost::collective
|
||||||
41
tests/cpp/collective/net_test.h
Normal file
41
tests/cpp/collective/net_test.h
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022-2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/collective/socket.h>
|
||||||
|
|
||||||
|
#include <fstream> // ifstream
|
||||||
|
|
||||||
|
#include "../helpers.h" // for FileExists
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
class SocketTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
std::string skip_msg_{"Skipping IPv6 test"};
|
||||||
|
|
||||||
|
bool SkipTest() {
|
||||||
|
std::string path{"/sys/module/ipv6/parameters/disable"};
|
||||||
|
if (FileExists(path)) {
|
||||||
|
std::ifstream fin(path);
|
||||||
|
if (!fin) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
std::string s_value;
|
||||||
|
fin >> s_value;
|
||||||
|
auto value = std::stoi(s_value);
|
||||||
|
if (value != 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void SetUp() override { system::SocketStartup(); }
|
||||||
|
void TearDown() override { system::SocketFinalize(); }
|
||||||
|
};
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -1,19 +1,16 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2022-2023 by XGBoost Contributors
|
* Copyright 2022-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/collective/socket.h>
|
#include <xgboost/collective/socket.h>
|
||||||
|
|
||||||
#include <cerrno> // EADDRNOTAVAIL
|
#include <cerrno> // EADDRNOTAVAIL
|
||||||
#include <fstream> // ifstream
|
|
||||||
#include <system_error> // std::error_code, std::system_category
|
#include <system_error> // std::error_code, std::system_category
|
||||||
|
|
||||||
#include "../helpers.h"
|
#include "net_test.h" // for SocketTest
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
TEST(Socket, Basic) {
|
TEST_F(SocketTest, Basic) {
|
||||||
system::SocketStartup();
|
|
||||||
|
|
||||||
SockAddress addr{SockAddrV6::Loopback()};
|
SockAddress addr{SockAddrV6::Loopback()};
|
||||||
ASSERT_TRUE(addr.IsV6());
|
ASSERT_TRUE(addr.IsV6());
|
||||||
addr = SockAddress{SockAddrV4::Loopback()};
|
addr = SockAddress{SockAddrV4::Loopback()};
|
||||||
@ -54,34 +51,27 @@ TEST(Socket, Basic) {
|
|||||||
|
|
||||||
run_test(SockDomain::kV4);
|
run_test(SockDomain::kV4);
|
||||||
|
|
||||||
std::string path{"/sys/module/ipv6/parameters/disable"};
|
if (SkipTest()) {
|
||||||
if (FileExists(path)) {
|
GTEST_SKIP_(skip_msg_.c_str());
|
||||||
std::ifstream fin(path);
|
|
||||||
if (!fin) {
|
|
||||||
GTEST_SKIP_(msg.c_str());
|
|
||||||
}
|
|
||||||
std::string s_value;
|
|
||||||
fin >> s_value;
|
|
||||||
auto value = std::stoi(s_value);
|
|
||||||
if (value != 0) {
|
|
||||||
GTEST_SKIP_(msg.c_str());
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
GTEST_SKIP_(msg.c_str());
|
|
||||||
}
|
}
|
||||||
run_test(SockDomain::kV6);
|
run_test(SockDomain::kV6);
|
||||||
|
|
||||||
system::SocketFinalize();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Socket, Bind) {
|
TEST_F(SocketTest, Bind) {
|
||||||
system::SocketStartup();
|
auto run = [](SockDomain domain) {
|
||||||
auto any = SockAddrV4::InaddrAny().Addr();
|
auto any =
|
||||||
auto sock = TCPSocket::Create(SockDomain::kV4);
|
domain == SockDomain::kV4 ? SockAddrV4::InaddrAny().Addr() : SockAddrV6::InaddrAny().Addr();
|
||||||
std::int32_t port{0};
|
auto sock = TCPSocket::Create(domain);
|
||||||
auto rc = sock.Bind(any, &port);
|
std::int32_t port{0};
|
||||||
ASSERT_TRUE(rc.OK());
|
auto rc = sock.Bind(any, &port);
|
||||||
ASSERT_NE(port, 0);
|
ASSERT_TRUE(rc.OK());
|
||||||
system::SocketFinalize();
|
ASSERT_NE(port, 0);
|
||||||
|
};
|
||||||
|
|
||||||
|
run(SockDomain::kV4);
|
||||||
|
if (SkipTest()) {
|
||||||
|
GTEST_SKIP_(skip_msg_.c_str());
|
||||||
|
}
|
||||||
|
run(SockDomain::kV6);
|
||||||
}
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
18
tests/cpp/collective/test_tracker.cc
Normal file
18
tests/cpp/collective/test_tracker.cc
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include "../../../src/collective/tracker.h" // for GetHostAddress
|
||||||
|
#include "net_test.h" // for SocketTest
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
namespace {
|
||||||
|
class TrackerTest : public SocketTest {};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TEST_F(TrackerTest, GetHostAddress) {
|
||||||
|
std::string host;
|
||||||
|
auto rc = GetHostAddress(&host);
|
||||||
|
ASSERT_TRUE(rc.OK());
|
||||||
|
ASSERT_TRUE(host.find("127.") == std::string::npos);
|
||||||
|
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
Loading…
x
Reference in New Issue
Block a user