enable ROCm on latest XGBoost
This commit is contained in:
41
tests/cpp/collective/net_test.h
Normal file
41
tests/cpp/collective/net_test.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/collective/socket.h>
|
||||
|
||||
#include <fstream> // ifstream
|
||||
|
||||
#include "../helpers.h" // for FileExists
|
||||
|
||||
namespace xgboost::collective {
|
||||
class SocketTest : public ::testing::Test {
|
||||
protected:
|
||||
std::string skip_msg_{"Skipping IPv6 test"};
|
||||
|
||||
bool SkipTest() {
|
||||
std::string path{"/sys/module/ipv6/parameters/disable"};
|
||||
if (FileExists(path)) {
|
||||
std::ifstream fin(path);
|
||||
if (!fin) {
|
||||
return true;
|
||||
}
|
||||
std::string s_value;
|
||||
fin >> s_value;
|
||||
auto value = std::stoi(s_value);
|
||||
if (value != 0) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override { system::SocketStartup(); }
|
||||
void TearDown() override { system::SocketFinalize(); }
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
117
tests/cpp/collective/test_allgather.cc
Normal file
117
tests/cpp/collective/test_allgather.cc
Normal file
@@ -0,0 +1,117 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h> // for ASSERT_EQ
|
||||
#include <xgboost/span.h> // for Span, oper...
|
||||
|
||||
#include <algorithm> // for min
|
||||
#include <chrono> // for seconds
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <numeric> // for iota
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/allgather.h" // for RingAllgather
|
||||
#include "../../../src/collective/comm.h" // for RabitComm
|
||||
#include "gtest/gtest.h" // for AssertionR...
|
||||
#include "test_worker.h" // for TestDistri...
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class AllgatherTest : public TrackerTest {};
|
||||
|
||||
class Worker : public WorkerForTest {
|
||||
public:
|
||||
using WorkerForTest::WorkerForTest;
|
||||
|
||||
void Run() {
|
||||
{
|
||||
// basic test
|
||||
std::vector<std::int32_t> data(comm_.World(), 0);
|
||||
data[comm_.Rank()] = comm_.Rank();
|
||||
|
||||
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}, 1);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
ASSERT_EQ(data[r], r);
|
||||
}
|
||||
}
|
||||
{
|
||||
// test for limited socket buffer
|
||||
this->LimitSockBuf(4096);
|
||||
|
||||
std::size_t n = 8192; // n_bytes = 8192 * sizeof(int)
|
||||
std::vector<std::int32_t> data(comm_.World() * n, 0);
|
||||
auto s_data = common::Span{data.data(), data.size()};
|
||||
auto seg = s_data.subspan(comm_.Rank() * n, n);
|
||||
std::iota(seg.begin(), seg.end(), comm_.Rank());
|
||||
|
||||
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}, n);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
auto seg = s_data.subspan(r * n, n);
|
||||
for (std::int32_t i = 0; i < static_cast<std::int32_t>(seg.size()); ++i) {
|
||||
auto v = seg[i];
|
||||
ASSERT_EQ(v, r + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TestV() {
|
||||
{
|
||||
// basic test
|
||||
std::int32_t n{comm_.Rank()};
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
ASSERT_EQ(result[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// V test
|
||||
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
|
||||
std::int32_t k{0};
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1));
|
||||
if (comm_.Rank() == 0) {
|
||||
for (auto v : seg) {
|
||||
ASSERT_EQ(v, r);
|
||||
}
|
||||
k += seg.size();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST_F(AllgatherTest, Basic) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.Run();
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllgatherTest, V) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.TestV();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
72
tests/cpp/collective/test_allreduce.cc
Normal file
72
tests/cpp/collective/test_allreduce.cc
Normal file
@@ -0,0 +1,72 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/collective/allreduce.h"
|
||||
#include "../../../src/collective/tracker.h"
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
namespace {
|
||||
class AllreduceWorker : public WorkerForTest {
|
||||
public:
|
||||
using WorkerForTest::WorkerForTest;
|
||||
|
||||
void Basic() {
|
||||
{
|
||||
std::vector<double> data(13, 0.0);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
|
||||
}
|
||||
{
|
||||
std::vector<double> data(1, 1.0);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
|
||||
}
|
||||
}
|
||||
|
||||
void Acc() {
|
||||
std::vector<double> data(314, 1.5);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
for (std::size_t i = 0; i < data.size(); ++i) {
|
||||
auto v = data[i];
|
||||
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class AllreduceTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(AllreduceTest, Basic) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
AllreduceWorker worker{host, port, timeout, n_workers, r};
|
||||
worker.Basic();
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllreduceTest, Sum) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
AllreduceWorker worker{host, port, timeout, n_workers, r};
|
||||
worker.Acc();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
51
tests/cpp/collective/test_broadcast.cc
Normal file
51
tests/cpp/collective/test_broadcast.cc
Normal file
@@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/collective/socket.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/broadcast.h" // for Broadcast
|
||||
#include "../../../src/collective/tracker.h" // for GetHostAddress
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class Worker : public WorkerForTest {
|
||||
public:
|
||||
using WorkerForTest::WorkerForTest;
|
||||
|
||||
void Run() {
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
// basic test
|
||||
std::vector<std::int32_t> data(1, comm_.Rank());
|
||||
auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(data[0], r);
|
||||
}
|
||||
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
std::vector<std::int32_t> data(1 << 16, comm_.Rank());
|
||||
auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(data[0], r);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class BroadcastTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(BroadcastTest, Basic) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.Run();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
47
tests/cpp/collective/test_comm.cc
Normal file
47
tests/cpp/collective/test_comm.cc
Normal file
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "test_worker.h"
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class CommTest : public TrackerTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(CommTest, Channel) {
|
||||
auto n_workers = 4;
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
WorkerForTest worker{host, port, timeout, n_workers, i};
|
||||
if (i % 2 == 0) {
|
||||
auto p_chan = worker.Comm().Chan(i + 1);
|
||||
p_chan->SendAll(
|
||||
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
|
||||
auto rc = p_chan->Block();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
} else {
|
||||
auto p_chan = worker.Comm().Chan(i - 1);
|
||||
std::int32_t r{-1};
|
||||
p_chan->RecvAll(EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
|
||||
auto rc = p_chan->Block();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(r, i - 1);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -29,6 +29,11 @@ class InMemoryCommunicatorTest : public ::testing::Test {
|
||||
VerifyAllgather(comm, rank);
|
||||
}
|
||||
|
||||
static void AllgatherV(int rank) {
|
||||
InMemoryCommunicator comm{kWorldSize, rank};
|
||||
VerifyAllgatherV(comm, rank);
|
||||
}
|
||||
|
||||
static void AllreduceMax(int rank) {
|
||||
InMemoryCommunicator comm{kWorldSize, rank};
|
||||
VerifyAllreduceMax(comm, rank);
|
||||
@@ -80,14 +85,19 @@ class InMemoryCommunicatorTest : public ::testing::Test {
|
||||
|
||||
protected:
|
||||
static void VerifyAllgather(InMemoryCommunicator &comm, int rank) {
|
||||
char buffer[kWorldSize] = {'a', 'b', 'c'};
|
||||
buffer[rank] = '0' + rank;
|
||||
comm.AllGather(buffer, kWorldSize);
|
||||
std::string input{static_cast<char>('0' + rank)};
|
||||
auto output = comm.AllGather(input);
|
||||
for (auto i = 0; i < kWorldSize; i++) {
|
||||
EXPECT_EQ(buffer[i], '0' + i);
|
||||
EXPECT_EQ(output[i], static_cast<char>('0' + i));
|
||||
}
|
||||
}
|
||||
|
||||
static void VerifyAllgatherV(InMemoryCommunicator &comm, int rank) {
|
||||
std::vector<std::string_view> inputs{"a", "bb", "ccc"};
|
||||
auto output = comm.AllGatherV(inputs[rank]);
|
||||
EXPECT_EQ(output, "abbccc");
|
||||
}
|
||||
|
||||
static void VerifyAllreduceMax(InMemoryCommunicator &comm, int rank) {
|
||||
int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
|
||||
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMax);
|
||||
@@ -205,6 +215,8 @@ TEST(InMemoryCommunicatorSimpleTest, IsDistributed) {
|
||||
|
||||
TEST_F(InMemoryCommunicatorTest, Allgather) { Verify(&Allgather); }
|
||||
|
||||
TEST_F(InMemoryCommunicatorTest, AllgatherV) { Verify(&AllgatherV); }
|
||||
|
||||
TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); }
|
||||
|
||||
TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); }
|
||||
|
||||
81
tests/cpp/collective/test_loop.cc
Normal file
81
tests/cpp/collective/test_loop.cc
Normal file
@@ -0,0 +1,81 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h> // for ASSERT_TRUE, ASSERT_EQ
|
||||
#include <xgboost/collective/socket.h> // for TCPSocket, Connect, SocketFinalize, SocketStartup
|
||||
#include <xgboost/string_view.h> // for StringView
|
||||
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int8_t
|
||||
#include <memory> // for make_shared, shared_ptr
|
||||
#include <system_error> // for make_error_code, errc
|
||||
#include <utility> // for pair
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/loop.h" // for Loop
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class LoopTest : public ::testing::Test {
|
||||
protected:
|
||||
std::pair<TCPSocket, TCPSocket> pair_;
|
||||
std::shared_ptr<Loop> loop_;
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
system::SocketStartup();
|
||||
std::chrono::seconds timeout{1};
|
||||
|
||||
auto domain = SockDomain::kV4;
|
||||
pair_.first = TCPSocket::Create(domain);
|
||||
auto port = pair_.first.BindHost();
|
||||
pair_.first.Listen();
|
||||
|
||||
auto const& addr = SockAddrV4::Loopback().Addr();
|
||||
auto rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
rc = pair_.second.NonBlocking(true);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
|
||||
pair_.first = pair_.first.Accept();
|
||||
rc = pair_.first.NonBlocking(true);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
|
||||
loop_ = std::make_shared<Loop>(timeout);
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
pair_ = decltype(pair_){};
|
||||
system::SocketFinalize();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST_F(LoopTest, Timeout) {
|
||||
std::vector<std::int8_t> data(1);
|
||||
Loop::Op op{Loop::Op::kRead, 0, data.data(), data.size(), &pair_.second, 0};
|
||||
loop_->Submit(op);
|
||||
auto rc = loop_->Block();
|
||||
ASSERT_FALSE(rc.OK());
|
||||
ASSERT_EQ(rc.Code(), std::make_error_code(std::errc::timed_out)) << rc.Report();
|
||||
}
|
||||
|
||||
TEST_F(LoopTest, Op) {
|
||||
TCPSocket& send = pair_.first;
|
||||
TCPSocket& recv = pair_.second;
|
||||
|
||||
std::vector<std::int8_t> wbuf(1, 1);
|
||||
std::vector<std::int8_t> rbuf(1, 0);
|
||||
|
||||
Loop::Op wop{Loop::Op::kWrite, 0, wbuf.data(), wbuf.size(), &send, 0};
|
||||
Loop::Op rop{Loop::Op::kRead, 0, rbuf.data(), rbuf.size(), &recv, 0};
|
||||
|
||||
loop_->Submit(wop);
|
||||
loop_->Submit(rop);
|
||||
|
||||
auto rc = loop_->Block();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
|
||||
ASSERT_EQ(rbuf[0], wbuf[0]);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -38,7 +38,7 @@ void VerifyAllReduceBitwiseAND() {
|
||||
auto const rank = collective::GetRank();
|
||||
std::bitset<64> original{};
|
||||
original[rank] = true;
|
||||
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, rank);
|
||||
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, DeviceOrd::CUDA(rank));
|
||||
collective::AllReduce<collective::Operation::kBitwiseAND>(rank, buffer.DevicePointer(), 1);
|
||||
collective::Synchronize(rank);
|
||||
EXPECT_EQ(buffer.HostVector()[0], 0ULL);
|
||||
@@ -60,7 +60,7 @@ void VerifyAllReduceBitwiseOR() {
|
||||
auto const rank = collective::GetRank();
|
||||
std::bitset<64> original{};
|
||||
original[rank] = true;
|
||||
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, rank);
|
||||
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, DeviceOrd::CUDA(rank));
|
||||
collective::AllReduce<collective::Operation::kBitwiseOR>(rank, buffer.DevicePointer(), 1);
|
||||
collective::Synchronize(rank);
|
||||
EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1);
|
||||
@@ -82,7 +82,7 @@ void VerifyAllReduceBitwiseXOR() {
|
||||
auto const rank = collective::GetRank();
|
||||
std::bitset<64> original{~0ULL};
|
||||
original[rank] = false;
|
||||
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, rank);
|
||||
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, DeviceOrd::CUDA(rank));
|
||||
collective::AllReduce<collective::Operation::kBitwiseXOR>(rank, buffer.DevicePointer(), 1);
|
||||
collective::Synchronize(rank);
|
||||
EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1);
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost Contributors
|
||||
* Copyright 2022-2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/collective/socket.h>
|
||||
|
||||
#include <cerrno> // EADDRNOTAVAIL
|
||||
#include <fstream> // ifstream
|
||||
#include <system_error> // std::error_code, std::system_category
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "test_worker.h" // for SocketTest
|
||||
|
||||
namespace xgboost::collective {
|
||||
TEST(Socket, Basic) {
|
||||
system::SocketStartup();
|
||||
|
||||
TEST_F(SocketTest, Basic) {
|
||||
SockAddress addr{SockAddrV6::Loopback()};
|
||||
ASSERT_TRUE(addr.IsV6());
|
||||
addr = SockAddress{SockAddrV4::Loopback()};
|
||||
@@ -54,23 +51,27 @@ TEST(Socket, Basic) {
|
||||
|
||||
run_test(SockDomain::kV4);
|
||||
|
||||
std::string path{"/sys/module/ipv6/parameters/disable"};
|
||||
if (FileExists(path)) {
|
||||
std::ifstream fin(path);
|
||||
if (!fin) {
|
||||
GTEST_SKIP_(msg.c_str());
|
||||
}
|
||||
std::string s_value;
|
||||
fin >> s_value;
|
||||
auto value = std::stoi(s_value);
|
||||
if (value != 0) {
|
||||
GTEST_SKIP_(msg.c_str());
|
||||
}
|
||||
} else {
|
||||
GTEST_SKIP_(msg.c_str());
|
||||
if (SkipTest()) {
|
||||
GTEST_SKIP_(skip_msg_.c_str());
|
||||
}
|
||||
run_test(SockDomain::kV6);
|
||||
}
|
||||
|
||||
system::SocketFinalize();
|
||||
TEST_F(SocketTest, Bind) {
|
||||
auto run = [](SockDomain domain) {
|
||||
auto any =
|
||||
domain == SockDomain::kV4 ? SockAddrV4::InaddrAny().Addr() : SockAddrV6::InaddrAny().Addr();
|
||||
auto sock = TCPSocket::Create(domain);
|
||||
std::int32_t port{0};
|
||||
auto rc = sock.Bind(any, &port);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
ASSERT_NE(port, 0);
|
||||
};
|
||||
|
||||
run(SockDomain::kV4);
|
||||
if (SkipTest()) {
|
||||
GTEST_SKIP_(skip_msg_.c_str());
|
||||
}
|
||||
run(SockDomain::kV6);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
67
tests/cpp/collective/test_tracker.cc
Normal file
67
tests/cpp/collective/test_tracker.cc
Normal file
@@ -0,0 +1,67 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "test_worker.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class PrintWorker : public WorkerForTest {
|
||||
public:
|
||||
using WorkerForTest::WorkerForTest;
|
||||
|
||||
void Print() {
|
||||
auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank()));
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST_F(TrackerTest, Bootstrap) {
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; });
|
||||
}
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
|
||||
TEST_F(TrackerTest, Print) {
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
PrintWorker worker{host, port, timeout, n_workers, i};
|
||||
worker.Print();
|
||||
});
|
||||
}
|
||||
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
|
||||
TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); }
|
||||
} // namespace xgboost::collective
|
||||
114
tests/cpp/collective/test_worker.h
Normal file
114
tests/cpp/collective/test_worker.h
Normal file
@@ -0,0 +1,114 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "../../../src/collective/tracker.h" // for GetHostAddress
|
||||
#include "../helpers.h" // for FileExists
|
||||
|
||||
namespace xgboost::collective {
|
||||
class WorkerForTest {
|
||||
std::string tracker_host_;
|
||||
std::int32_t tracker_port_;
|
||||
std::int32_t world_size_;
|
||||
|
||||
protected:
|
||||
std::int32_t retry_{1};
|
||||
std::string task_id_;
|
||||
RabitComm comm_;
|
||||
|
||||
public:
|
||||
WorkerForTest(std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t world, std::int32_t rank)
|
||||
: tracker_host_{std::move(host)},
|
||||
tracker_port_{port},
|
||||
world_size_{world},
|
||||
task_id_{"t:" + std::to_string(rank)},
|
||||
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_} {
|
||||
CHECK_EQ(world_size_, comm_.World());
|
||||
}
|
||||
virtual ~WorkerForTest() = default;
|
||||
auto& Comm() { return comm_; }
|
||||
|
||||
void LimitSockBuf(std::int32_t n_bytes) {
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
if (i != comm_.Rank()) {
|
||||
ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking());
|
||||
ASSERT_TRUE(comm_.Chan(i)->Socket()->SetBufSize(n_bytes).OK());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class SocketTest : public ::testing::Test {
|
||||
protected:
|
||||
std::string skip_msg_{"Skipping IPv6 test"};
|
||||
|
||||
bool SkipTest() {
|
||||
std::string path{"/sys/module/ipv6/parameters/disable"};
|
||||
if (FileExists(path)) {
|
||||
std::ifstream fin(path);
|
||||
if (!fin) {
|
||||
return true;
|
||||
}
|
||||
std::string s_value;
|
||||
fin >> s_value;
|
||||
auto value = std::stoi(s_value);
|
||||
if (value != 0) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override { system::SocketStartup(); }
|
||||
void TearDown() override { system::SocketFinalize(); }
|
||||
};
|
||||
|
||||
class TrackerTest : public SocketTest {
|
||||
public:
|
||||
std::int32_t n_workers{2};
|
||||
std::chrono::seconds timeout{1};
|
||||
std::string host;
|
||||
|
||||
void SetUp() override {
|
||||
SocketTest::SetUp();
|
||||
auto rc = GetHostAddress(&host);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WorkerFn>
|
||||
void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
|
||||
std::chrono::seconds timeout{1};
|
||||
|
||||
std::string host;
|
||||
ASSERT_TRUE(GetHostAddress(&host).OK());
|
||||
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] { worker_fn(host, port, timeout, i); });
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
Reference in New Issue
Block a user