[Coll] Implement get host address in libxgboost. (#9644)

- Port `xgboost.tracker.get_host_ip` in C++.
This commit is contained in:
Jiaming Yuan
2023-10-10 10:01:14 +08:00
committed by GitHub
parent 680d53db43
commit b14e535e78
6 changed files with 199 additions and 31 deletions

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2022-2023, XGBoost Contributors
*/
#pragma once
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>
#include <fstream> // ifstream
#include "../helpers.h" // for FileExists
namespace xgboost::collective {
class SocketTest : public ::testing::Test {
protected:
std::string skip_msg_{"Skipping IPv6 test"};
bool SkipTest() {
std::string path{"/sys/module/ipv6/parameters/disable"};
if (FileExists(path)) {
std::ifstream fin(path);
if (!fin) {
return true;
}
std::string s_value;
fin >> s_value;
auto value = std::stoi(s_value);
if (value != 0) {
return true;
}
} else {
return true;
}
return false;
}
protected:
void SetUp() override { system::SocketStartup(); }
void TearDown() override { system::SocketFinalize(); }
};
} // namespace xgboost::collective

View File

@@ -1,19 +1,16 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2023, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>
#include <cerrno> // EADDRNOTAVAIL
#include <fstream> // ifstream
#include <system_error> // std::error_code, std::system_category
#include "../helpers.h"
#include "net_test.h" // for SocketTest
namespace xgboost::collective {
TEST(Socket, Basic) {
system::SocketStartup();
TEST_F(SocketTest, Basic) {
SockAddress addr{SockAddrV6::Loopback()};
ASSERT_TRUE(addr.IsV6());
addr = SockAddress{SockAddrV4::Loopback()};
@@ -54,34 +51,27 @@ TEST(Socket, Basic) {
run_test(SockDomain::kV4);
std::string path{"/sys/module/ipv6/parameters/disable"};
if (FileExists(path)) {
std::ifstream fin(path);
if (!fin) {
GTEST_SKIP_(msg.c_str());
}
std::string s_value;
fin >> s_value;
auto value = std::stoi(s_value);
if (value != 0) {
GTEST_SKIP_(msg.c_str());
}
} else {
GTEST_SKIP_(msg.c_str());
if (SkipTest()) {
GTEST_SKIP_(skip_msg_.c_str());
}
run_test(SockDomain::kV6);
system::SocketFinalize();
}
TEST(Socket, Bind) {
system::SocketStartup();
auto any = SockAddrV4::InaddrAny().Addr();
auto sock = TCPSocket::Create(SockDomain::kV4);
std::int32_t port{0};
auto rc = sock.Bind(any, &port);
ASSERT_TRUE(rc.OK());
ASSERT_NE(port, 0);
system::SocketFinalize();
TEST_F(SocketTest, Bind) {
auto run = [](SockDomain domain) {
auto any =
domain == SockDomain::kV4 ? SockAddrV4::InaddrAny().Addr() : SockAddrV6::InaddrAny().Addr();
auto sock = TCPSocket::Create(domain);
std::int32_t port{0};
auto rc = sock.Bind(any, &port);
ASSERT_TRUE(rc.OK());
ASSERT_NE(port, 0);
};
run(SockDomain::kV4);
if (SkipTest()) {
GTEST_SKIP_(skip_msg_.c_str());
}
run(SockDomain::kV6);
}
} // namespace xgboost::collective

View File

@@ -0,0 +1,18 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "../../../src/collective/tracker.h" // for GetHostAddress
#include "net_test.h" // for SocketTest
namespace xgboost::collective {
namespace {
class TrackerTest : public SocketTest {};
} // namespace
TEST_F(TrackerTest, GetHostAddress) {
std::string host;
auto rc = GetHostAddress(&host);
ASSERT_TRUE(rc.OK());
ASSERT_TRUE(host.find("127.") == std::string::npos);
}
} // namespace xgboost::collective