/** * Copyright 2023, XGBoost Contributors */ #include #include // for seconds #include // for int32_t #include // for string #include // for thread #include // for move #include // for vector #include "../../../src/collective/comm.h" #include "../../../src/collective/tracker.h" // for GetHostAddress #include "../helpers.h" // for FileExists namespace xgboost::collective { class WorkerForTest { std::string tracker_host_; std::int32_t tracker_port_; std::int32_t world_size_; protected: std::int32_t retry_{1}; std::string task_id_; RabitComm comm_; public: WorkerForTest(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t world, std::int32_t rank) : tracker_host_{std::move(host)}, tracker_port_{port}, world_size_{world}, task_id_{"t:" + std::to_string(rank)}, comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_} { CHECK_EQ(world_size_, comm_.World()); } virtual ~WorkerForTest() = default; auto& Comm() { return comm_; } void LimitSockBuf(std::int32_t n_bytes) { for (std::int32_t i = 0; i < comm_.World(); ++i) { if (i != comm_.Rank()) { ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking()); ASSERT_TRUE(comm_.Chan(i)->Socket()->SetBufSize(n_bytes).OK()); } } } }; class SocketTest : public ::testing::Test { protected: std::string skip_msg_{"Skipping IPv6 test"}; bool SkipTest() { std::string path{"/sys/module/ipv6/parameters/disable"}; if (FileExists(path)) { std::ifstream fin(path); if (!fin) { return true; } std::string s_value; fin >> s_value; auto value = std::stoi(s_value); if (value != 0) { return true; } } else { return true; } return false; } protected: void SetUp() override { system::SocketStartup(); } void TearDown() override { system::SocketFinalize(); } }; class TrackerTest : public SocketTest { public: std::int32_t n_workers{2}; std::chrono::seconds timeout{1}; std::string host; void SetUp() override { SocketTest::SetUp(); auto rc = GetHostAddress(&host); ASSERT_TRUE(rc.OK()) << rc.Report(); } }; } // namespace xgboost::collective