Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhausted tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -1,41 +0,0 @@
/**
* Copyright 2022-2023, XGBoost Contributors
*/
#pragma once
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>
#include <fstream> // ifstream
#include "../helpers.h" // for FileExists
namespace xgboost::collective {
class SocketTest : public ::testing::Test {
protected:
std::string skip_msg_{"Skipping IPv6 test"};
bool SkipTest() {
std::string path{"/sys/module/ipv6/parameters/disable"};
if (FileExists(path)) {
std::ifstream fin(path);
if (!fin) {
return true;
}
std::string s_value;
fin >> s_value;
auto value = std::stoi(s_value);
if (value != 0) {
return true;
}
} else {
return true;
}
return false;
}
protected:
void SetUp() override { system::SocketStartup(); }
void TearDown() override { system::SocketFinalize(); }
};
} // namespace xgboost::collective

View File

@@ -175,4 +175,35 @@ TEST_F(AllgatherTest, VAlgo) {
worker.TestVAlgo();
});
}
TEST(VectorAllgatherV, Basic) {
std::int32_t n_workers{3};
TestDistributedGlobal(n_workers, []() {
auto n_workers = collective::GetWorldSize();
ASSERT_EQ(n_workers, 3);
auto rank = collective::GetRank();
// Construct input that has different length for each worker.
std::vector<std::vector<char>> inputs;
for (std::int32_t i = 0; i < rank + 1; ++i) {
std::vector<char> in;
for (std::int32_t j = 0; j < rank + 1; ++j) {
in.push_back(static_cast<char>(j));
}
inputs.emplace_back(std::move(in));
}
Context ctx;
auto outputs = VectorAllgatherV(&ctx, inputs);
ASSERT_EQ(outputs.size(), (1 + n_workers) * n_workers / 2);
auto const& res = outputs;
for (std::int32_t i = 0; i < n_workers; ++i) {
std::int32_t k = 0;
for (auto v : res[i]) {
ASSERT_EQ(v, k++);
}
}
});
}
} // namespace xgboost::collective

View File

@@ -39,6 +39,22 @@ class AllreduceWorker : public WorkerForTest {
}
}
void Restricted() {
this->LimitSockBuf(4096);
std::size_t n = 4096 * 4;
std::vector<std::int32_t> data(comm_.World() * n, 1);
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_TRUE(rc.OK());
for (auto v : data) {
ASSERT_EQ(v, comm_.World());
}
}
void Acc() {
std::vector<double> data(314, 1.5);
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
@@ -95,4 +111,45 @@ TEST_F(AllreduceTest, BitOr) {
worker.BitOr();
});
}
TEST_F(AllreduceTest, Restricted) {
std::int32_t n_workers = std::min(3u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
AllreduceWorker worker{host, port, timeout, n_workers, r};
worker.Restricted();
});
}
TEST(AllreduceGlobal, Basic) {
auto n_workers = 3;
TestDistributedGlobal(n_workers, [&]() {
std::vector<float> values(n_workers * 2, 0);
auto rank = GetRank();
auto s_values = common::Span{values.data(), values.size()};
auto self = s_values.subspan(rank * 2, 2);
for (auto& v : self) {
v = 1.0f;
}
Context ctx;
auto rc =
Allreduce(&ctx, linalg::MakeVec(s_values.data(), s_values.size()), collective::Op::kSum);
SafeColl(rc);
for (auto v : s_values) {
ASSERT_EQ(v, 1);
}
});
}
TEST(AllreduceGlobal, Small) {
// Test when the data is not large enougth to be divided by the number of workers
auto n_workers = 8;
TestDistributedGlobal(n_workers, [&]() {
std::uint64_t value{1};
Context ctx;
auto rc = Allreduce(&ctx, linalg::MakeVec(&value, 1), collective::Op::kSum);
SafeColl(rc);
ASSERT_EQ(value, n_workers);
});
}
} // namespace xgboost::collective

View File

@@ -1,63 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include "../../../src/collective/communicator.h"
namespace xgboost {
namespace collective {
TEST(CommunicatorFactory, TypeFromEnv) {
EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromEnv());
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "foo");
EXPECT_THROW(Communicator::GetTypeFromEnv(), dmlc::Error);
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "rabit");
EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromEnv());
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "Federated");
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromEnv());
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "In-Memory");
EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromEnv());
}
TEST(CommunicatorFactory, TypeFromArgs) {
Json config{JsonObject()};
EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config));
config["xgboost_communicator"] = String("rabit");
EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config));
config["xgboost_communicator"] = String("federated");
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));
config["xgboost_communicator"] = String("in-memory");
EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromConfig(config));
config["xgboost_communicator"] = String("foo");
EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
}
TEST(CommunicatorFactory, TypeFromArgsUpperCase) {
Json config{JsonObject()};
EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config));
config["XGBOOST_COMMUNICATOR"] = String("rabit");
EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config));
config["XGBOOST_COMMUNICATOR"] = String("federated");
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));
config["XGBOOST_COMMUNICATOR"] = String("in-memory");
EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromConfig(config));
config["XGBOOST_COMMUNICATOR"] = String("foo");
EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
}
} // namespace collective
} // namespace xgboost

View File

@@ -1,237 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include <bitset>
#include <thread>
#include "../../../src/collective/in_memory_communicator.h"
namespace xgboost {
namespace collective {
class InMemoryCommunicatorTest : public ::testing::Test {
public:
static void Verify(void (*function)(int)) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(function, rank);
}
for (auto &thread : threads) {
thread.join();
}
}
static void Allgather(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
VerifyAllgather(comm, rank);
}
static void AllgatherV(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
VerifyAllgatherV(comm, rank);
}
static void AllreduceMax(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
VerifyAllreduceMax(comm, rank);
}
static void AllreduceMin(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
VerifyAllreduceMin(comm, rank);
}
static void AllreduceSum(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
VerifyAllreduceSum(comm);
}
static void AllreduceBitwiseAND(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
VerifyAllreduceBitwiseAND(comm, rank);
}
static void AllreduceBitwiseOR(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
VerifyAllreduceBitwiseOR(comm, rank);
}
static void AllreduceBitwiseXOR(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
VerifyAllreduceBitwiseXOR(comm, rank);
}
static void Broadcast(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
VerifyBroadcast(comm, rank);
}
static void Mixture(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
for (auto i = 0; i < 5; i++) {
VerifyAllgather(comm, rank);
VerifyAllreduceMax(comm, rank);
VerifyAllreduceMin(comm, rank);
VerifyAllreduceSum(comm);
VerifyAllreduceBitwiseAND(comm, rank);
VerifyAllreduceBitwiseOR(comm, rank);
VerifyAllreduceBitwiseXOR(comm, rank);
VerifyBroadcast(comm, rank);
}
}
protected:
static void VerifyAllgather(InMemoryCommunicator &comm, int rank) {
std::string input{static_cast<char>('0' + rank)};
auto output = comm.AllGather(input);
for (auto i = 0; i < kWorldSize; i++) {
EXPECT_EQ(output[i], static_cast<char>('0' + i));
}
}
static void VerifyAllgatherV(InMemoryCommunicator &comm, int rank) {
std::vector<std::string_view> inputs{"a", "bb", "ccc"};
auto output = comm.AllGatherV(inputs[rank]);
EXPECT_EQ(output, "abbccc");
}
static void VerifyAllreduceMax(InMemoryCommunicator &comm, int rank) {
int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMax);
int expected[] = {3, 4, 5, 6, 7};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(buffer[i], expected[i]);
}
}
static void VerifyAllreduceMin(InMemoryCommunicator &comm, int rank) {
int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMin);
int expected[] = {1, 2, 3, 4, 5};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(buffer[i], expected[i]);
}
}
static void VerifyAllreduceSum(InMemoryCommunicator &comm) {
int buffer[] = {1, 2, 3, 4, 5};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
int expected[] = {3, 6, 9, 12, 15};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(buffer[i], expected[i]);
}
}
static void VerifyAllreduceBitwiseAND(InMemoryCommunicator &comm, int rank) {
std::bitset<2> original(rank);
auto buffer = original.to_ulong();
comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseAND);
EXPECT_EQ(buffer, 0UL);
}
static void VerifyAllreduceBitwiseOR(InMemoryCommunicator &comm, int rank) {
std::bitset<2> original(rank);
auto buffer = original.to_ulong();
comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseOR);
std::bitset<2> actual(buffer);
std::bitset<2> expected{0b11};
EXPECT_EQ(actual, expected);
}
static void VerifyAllreduceBitwiseXOR(InMemoryCommunicator &comm, int rank) {
std::bitset<3> original(rank * 2);
auto buffer = original.to_ulong();
comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseXOR);
std::bitset<3> actual(buffer);
std::bitset<3> expected{0b110};
EXPECT_EQ(actual, expected);
}
static void VerifyBroadcast(InMemoryCommunicator &comm, int rank) {
if (rank == 0) {
std::string buffer{"hello"};
comm.Broadcast(&buffer[0], buffer.size(), 0);
EXPECT_EQ(buffer, "hello");
} else {
std::string buffer{" "};
comm.Broadcast(&buffer[0], buffer.size(), 0);
EXPECT_EQ(buffer, "hello");
}
}
static int const kWorldSize{3};
};
TEST(InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
auto construct = []() { InMemoryCommunicator comm{0, 0}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankTooSmall) {
auto construct = []() { InMemoryCommunicator comm{1, -1}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankTooBig) {
auto construct = []() { InMemoryCommunicator comm{1, 1}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
auto construct = []() {
Json config{JsonObject()};
config["in_memory_world_size"] = std::string("1");
config["in_memory_rank"] = Integer(0);
auto *comm = InMemoryCommunicator::Create(config);
delete comm;
};
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankNotInteger) {
auto construct = []() {
Json config{JsonObject()};
config["in_memory_world_size"] = 1;
config["in_memory_rank"] = std::string("0");
auto *comm = InMemoryCommunicator::Create(config);
delete comm;
};
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(InMemoryCommunicatorSimpleTest, GetWorldSizeAndRank) {
InMemoryCommunicator comm{1, 0};
EXPECT_EQ(comm.GetWorldSize(), 1);
EXPECT_EQ(comm.GetRank(), 0);
}
TEST(InMemoryCommunicatorSimpleTest, IsDistributed) {
InMemoryCommunicator comm{1, 0};
EXPECT_TRUE(comm.IsDistributed());
}
TEST_F(InMemoryCommunicatorTest, Allgather) { Verify(&Allgather); }
TEST_F(InMemoryCommunicatorTest, AllgatherV) { Verify(&AllgatherV); }
TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); }
TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); }
TEST_F(InMemoryCommunicatorTest, AllreduceSum) { Verify(&AllreduceSum); }
TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseAND) { Verify(&AllreduceBitwiseAND); }
TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseOR) { Verify(&AllreduceBitwiseOR); }
TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseXOR) { Verify(&AllreduceBitwiseXOR); }
TEST_F(InMemoryCommunicatorTest, Broadcast) { Verify(&Broadcast); }
TEST_F(InMemoryCommunicatorTest, Mixture) { Verify(&Mixture); }
} // namespace collective
} // namespace xgboost

View File

@@ -59,7 +59,7 @@ class LoopTest : public ::testing::Test {
TEST_F(LoopTest, Timeout) {
std::vector<std::int8_t> data(1);
Loop::Op op{Loop::Op::kRead, 0, data.data(), data.size(), &pair_.second, 0};
loop_->Submit(op);
loop_->Submit(std::move(op));
auto rc = loop_->Block();
ASSERT_FALSE(rc.OK());
ASSERT_EQ(rc.Code(), std::make_error_code(std::errc::timed_out)) << rc.Report();
@@ -75,8 +75,8 @@ TEST_F(LoopTest, Op) {
Loop::Op wop{Loop::Op::kWrite, 0, wbuf.data(), wbuf.size(), &send, 0};
Loop::Op rop{Loop::Op::kRead, 0, rbuf.data(), rbuf.size(), &recv, 0};
loop_->Submit(wop);
loop_->Submit(rop);
loop_->Submit(std::move(wop));
loop_->Submit(std::move(rop));
auto rc = loop_->Block();
SafeColl(rc);
@@ -90,7 +90,7 @@ TEST_F(LoopTest, Block) {
common::Timer t;
t.Start();
loop_->Submit(op);
loop_->Submit(std::move(op));
t.Stop();
// submit is non-blocking
ASSERT_LT(t.ElapsedSeconds(), 1);

View File

@@ -1,99 +0,0 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#ifdef XGBOOST_USE_NCCL
#include <gtest/gtest.h>
#include <bitset>
#include <string> // for string
#include "../../../src/collective/comm.cuh"
#include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/collective/nccl_device_communicator.cuh"
#include "../helpers.h"
namespace xgboost {
namespace collective {
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
auto construct = []() { NcclDeviceCommunicator comm{-1, false, DefaultNcclName()}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
auto stub = std::make_shared<NcclStub>(DefaultNcclName());
auto rc = stub->GetNcclResult(ncclSystemError);
auto msg = rc.Report();
ASSERT_TRUE(msg.find("environment variables") != std::string::npos);
}
namespace {
void VerifyAllReduceBitwiseAND() {
auto const rank = collective::GetRank();
std::bitset<64> original{};
original[rank] = true;
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, DeviceOrd::CUDA(rank));
collective::AllReduce<collective::Operation::kBitwiseAND>(rank, buffer.DevicePointer(), 1);
collective::Synchronize(rank);
EXPECT_EQ(buffer.HostVector()[0], 0ULL);
}
} // anonymous namespace
TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseAND) {
auto const n_gpus = common::AllVisibleGPUs();
if (n_gpus <= 1) {
GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseAND test with # GPUs = " << n_gpus;
}
auto constexpr kUseNccl = true;
RunWithInMemoryCommunicator<kUseNccl>(n_gpus, VerifyAllReduceBitwiseAND);
}
namespace {
void VerifyAllReduceBitwiseOR() {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
std::bitset<64> original{};
original[rank] = true;
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, DeviceOrd::CUDA(rank));
collective::AllReduce<collective::Operation::kBitwiseOR>(rank, buffer.DevicePointer(), 1);
collective::Synchronize(rank);
EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1);
}
} // anonymous namespace
TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseOR) {
auto const n_gpus = common::AllVisibleGPUs();
if (n_gpus <= 1) {
GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseOR test with # GPUs = " << n_gpus;
}
auto constexpr kUseNccl = true;
RunWithInMemoryCommunicator<kUseNccl>(n_gpus, VerifyAllReduceBitwiseOR);
}
namespace {
void VerifyAllReduceBitwiseXOR() {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
std::bitset<64> original{~0ULL};
original[rank] = false;
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, DeviceOrd::CUDA(rank));
collective::AllReduce<collective::Operation::kBitwiseXOR>(rank, buffer.DevicePointer(), 1);
collective::Synchronize(rank);
EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1);
}
} // anonymous namespace
TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseXOR) {
auto const n_gpus = common::AllVisibleGPUs();
if (n_gpus <= 1) {
GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseXOR test with # GPUs = " << n_gpus;
}
auto constexpr kUseNccl = true;
RunWithInMemoryCommunicator<kUseNccl>(n_gpus, VerifyAllReduceBitwiseXOR);
}
} // namespace collective
} // namespace xgboost
#endif // XGBOOST_USE_NCCL

View File

@@ -1,70 +0,0 @@
/**
* Copyright 2022-2024, XGBoost contributors
*/
#include <gtest/gtest.h>
#include "../../../src/collective/rabit_communicator.h"
#include "../helpers.h"
namespace xgboost::collective {
TEST(RabitCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
auto construct = []() { RabitCommunicator comm{0, 0}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooSmall) {
auto construct = []() { RabitCommunicator comm{1, -1}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooBig) {
auto construct = []() { RabitCommunicator comm{1, 1}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(RabitCommunicatorSimpleTest, GetWorldSizeAndRank) {
RabitCommunicator comm{6, 3};
EXPECT_EQ(comm.GetWorldSize(), 6);
EXPECT_EQ(comm.GetRank(), 3);
}
TEST(RabitCommunicatorSimpleTest, IsNotDistributed) {
RabitCommunicator comm{2, 1};
// Rabit is only distributed with a tracker.
EXPECT_FALSE(comm.IsDistributed());
}
namespace {
void VerifyVectorAllgatherV() {
auto n_workers = collective::GetWorldSize();
ASSERT_EQ(n_workers, 3);
auto rank = collective::GetRank();
// Construct input that has different length for each worker.
std::vector<std::vector<char>> inputs;
for (std::int32_t i = 0; i < rank + 1; ++i) {
std::vector<char> in;
for (std::int32_t j = 0; j < rank + 1; ++j) {
in.push_back(static_cast<char>(j));
}
inputs.emplace_back(std::move(in));
}
auto outputs = VectorAllgatherV(inputs);
ASSERT_EQ(outputs.size(), (1 + n_workers) * n_workers / 2);
auto const& res = outputs;
for (std::int32_t i = 0; i < n_workers; ++i) {
std::int32_t k = 0;
for (auto v : res[i]) {
ASSERT_EQ(v, k++);
}
}
}
} // namespace
TEST(VectorAllgatherV, Basic) {
std::int32_t n_workers{3};
RunWithInMemoryCommunicator(n_workers, VerifyVectorAllgatherV);
}
} // namespace xgboost::collective

View File

@@ -29,6 +29,7 @@ class PrintWorker : public WorkerForTest {
TEST_F(TrackerTest, Bootstrap) {
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
ASSERT_TRUE(HasTimeout(tracker.Timeout()));
ASSERT_FALSE(tracker.Ready());
auto fut = tracker.Run();
@@ -47,6 +48,9 @@ TEST_F(TrackerTest, Bootstrap) {
w.join();
}
SafeColl(fut.get());
ASSERT_FALSE(HasTimeout(std::chrono::seconds{-1}));
ASSERT_FALSE(HasTimeout(std::chrono::seconds{0}));
}
TEST_F(TrackerTest, Print) {

View File

@@ -16,6 +16,10 @@
#include "../../../src/collective/tracker.h" // for GetHostAddress
#include "../helpers.h" // for FileExists
#if defined(XGBOOST_USE_FEDERATED)
#include "../plugin/federated/test_worker.h"
#endif // defined(XGBOOST_USE_FEDERATED)
namespace xgboost::collective {
class WorkerForTest {
std::string tracker_host_;
@@ -45,6 +49,7 @@ class WorkerForTest {
if (i != comm_.Rank()) {
ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking());
ASSERT_TRUE(comm_.Chan(i)->Socket()->SetBufSize(n_bytes).OK());
ASSERT_TRUE(comm_.Chan(i)->Socket()->SetNoDelay().OK());
}
}
}
@@ -126,15 +131,80 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
ASSERT_TRUE(fut.get().OK());
}
inline auto MakeDistributedTestConfig(std::string host, std::int32_t port,
std::chrono::seconds timeout, std::int32_t r) {
Json config{Object{}};
config["dmlc_communicator"] = std::string{"rabit"};
config["dmlc_tracker_uri"] = host;
config["dmlc_tracker_port"] = port;
config["dmlc_timeout_sec"] = static_cast<std::int64_t>(timeout.count());
config["dmlc_timeout"] = static_cast<std::int64_t>(timeout.count());
config["dmlc_task_id"] = std::to_string(r);
config["dmlc_retry"] = 2;
return config;
}
template <typename WorkerFn>
void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true) {
system::SocketStartup();
std::chrono::seconds timeout{1};
std::string host;
auto rc = GetHostAddress(&host);
SafeColl(rc);
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
auto fut = tracker.Run();
std::vector<std::thread> workers;
std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) {
workers.emplace_back([=] {
auto config = MakeDistributedTestConfig(host, port, timeout, i);
Init(config);
worker_fn();
if (need_finalize) {
Finalize();
}
});
}
for (auto& t : workers) {
t.join();
}
ASSERT_TRUE(fut.get().OK());
system::SocketFinalize();
}
class BaseMGPUTest : public ::testing::Test {
public:
/**
* @param emulate_if_single Emulate multi-GPU for federated test if there's only one GPU
* available.
*/
template <typename Fn>
auto DoTest(Fn&& fn, bool is_federated, bool emulate_if_single = false) const {
auto n_gpus = common::AllVisibleGPUs();
if (is_federated) {
#if defined(XGBOOST_USE_FEDERATED)
if (n_gpus == 1 && emulate_if_single) {
// Emulate multiple GPUs.
// We don't use nccl and can have multiple communicators running on the same device.
n_gpus = 3;
}
TestFederatedGlobal(n_gpus, fn);
#else
GTEST_SKIP_("Not compiled with federated learning.");
#endif // defined(XGBOOST_USE_FEDERATED)
} else {
#if defined(XGBOOST_USE_NCCL)
TestDistributedGlobal(n_gpus, fn);
#else
GTEST_SKIP_("Not compiled with NCCL.");
#endif // defined(XGBOOST_USE_NCCL)
}
}
};
} // namespace xgboost::collective