Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -108,6 +108,32 @@ TEST_F(FederatedCollTestGPU, Allreduce) {
});
}
// Bitwise-OR allreduce through the global collective interface, with one federated worker
// per visible GPU. Every worker contributes a buffer filled with its own rank.
TEST(FederatedCollGPUGlobal, Allreduce) {
  std::int32_t const n_workers = common::AllVisibleGPUs();
  TestFederatedGlobal(n_workers, [&] {
    auto const rank = collective::GetRank();
    auto const world_size = collective::GetWorldSize();
    CHECK_EQ(n_workers, world_size);

    dh::device_vector<std::uint32_t> values(3, rank);
    auto ctx = MakeCUDACtx(rank);
    auto rc = collective::Allreduce(
        &ctx, linalg::MakeVec(values.data().get(), values.size(), DeviceOrd::CUDA(rank)),
        Op::kBitwiseOR);
    SafeColl(rc);

    // All elements share the same expected value: the OR of every rank in the group.
    std::uint32_t all_ranks{0};
    for (std::int32_t other = 0; other < world_size; ++other) {
      all_ranks |= other;
    }
    std::vector<std::uint32_t> const expected(values.size(), all_ranks);

    for (std::size_t idx = 0; idx < expected.size(); ++idx) {
      CHECK_EQ(expected[idx], values[idx]);
    }
  });
}
TEST_F(FederatedCollTestGPU, Broadcast) {
std::int32_t n_workers = common::AllVisibleGPUs();
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {

View File

@@ -11,12 +11,24 @@
#include "../../../../plugin/federated/federated_tracker.h"
#include "../../../../src/collective/comm_group.h"
#include "../../../../src/collective/communicator-inl.h"
#include "federated_comm.h" // for FederatedComm
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
/**
 * @brief Build the per-worker JSON configuration used by the federated test helpers.
 *
 * @param n_workers Stored as `federated_world_size`.
 * @param port      Appended to "0.0.0.0:" to form `federated_server_address`.
 * @param i         Worker index; stored as `federated_rank` and, as a string, as `dmlc_task_id`.
 *
 * @return The populated configuration object.
 */
inline Json FederatedTestConfig(std::int32_t n_workers, std::int32_t port, std::int32_t i) {
  Json worker_config{Object{}};

  // Settings specific to the federated communicator.
  worker_config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
  worker_config["federated_world_size"] = n_workers;
  worker_config["federated_rank"] = i;

  // Generic communicator settings.
  worker_config["dmlc_communicator"] = std::string{"federated"};
  worker_config["dmlc_task_id"] = std::to_string(i);
  worker_config["dmlc_retry"] = 2;

  return worker_config;
}
template <typename WorkerFn>
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) {
Json config{Object()};
config["federated_secure"] = Boolean{false};
config["n_workers"] = Integer{n_workers};
@@ -30,16 +42,7 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) {
workers.emplace_back([=] {
Json config{Object{}};
config["federated_world_size"] = n_workers;
config["federated_rank"] = i;
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
auto comm = std::make_shared<FederatedComm>(
DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config);
fn(comm, i);
});
workers.emplace_back([=] { fn(port, i); });
}
for (auto& t : workers) {
@@ -51,39 +54,33 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
ASSERT_TRUE(fut.get().OK());
}
/**
 * @brief Run `fn` with a freshly constructed FederatedComm for each worker.
 *
 * The per-worker invocation is delegated to TestFederatedImpl, which supplies the port
 * and the worker index `i`. For each invocation a configuration is built with
 * FederatedTestConfig, a std::shared_ptr<FederatedComm> is created from it, and `fn` is
 * called as `fn(comm, i)`.
 *
 * @param n_workers Number of workers, forwarded to TestFederatedImpl and the config.
 * @param fn        Callable accepting (std::shared_ptr<FederatedComm>, std::int32_t).
 */
template <typename WorkerFn>
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) {
auto config = FederatedTestConfig(n_workers, port, i);
// The worker index doubles as the task ID string passed to the communicator.
auto comm = std::make_shared<FederatedComm>(
DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config);
fn(comm, i);
});
}
template <typename WorkerFn>
void TestFederatedGroup(std::int32_t n_workers, WorkerFn&& fn) {
Json config{Object()};
config["federated_secure"] = Boolean{false};
config["n_workers"] = Integer{n_workers};
FederatedTracker tracker{config};
auto fut = tracker.Run();
TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) {
auto config = FederatedTestConfig(n_workers, port, i);
std::shared_ptr<CommGroup> comm_group{CommGroup::Create(config)};
fn(comm_group, i);
});
}
std::vector<std::thread> workers;
auto rc = tracker.WaitUntilReady();
ASSERT_TRUE(rc.OK()) << rc.Report();
std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) {
workers.emplace_back([=] {
Json config{Object{}};
config["dmlc_communicator"] = std::string{"federated"};
config["dmlc_task_id"] = std::to_string(i);
config["dmlc_retry"] = 2;
config["federated_world_size"] = n_workers;
config["federated_rank"] = i;
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
std::shared_ptr<CommGroup> comm_group{CommGroup::Create(config)};
fn(comm_group, i);
});
}
for (auto& t : workers) {
t.join();
}
rc = tracker.Shutdown();
ASSERT_TRUE(rc.OK()) << rc.Report();
ASSERT_TRUE(fut.get().OK());
/**
 * @brief Run `fn` between collective::Init and collective::Finalize for each worker.
 *
 * The per-worker invocation is delegated to TestFederatedImpl, which supplies the port
 * and the worker index `i`. For each invocation a configuration is built with
 * FederatedTestConfig and passed to collective::Init; `fn` is then called without
 * arguments, followed by collective::Finalize.
 *
 * NOTE(review): collective::Finalize is not reached if `fn` throws -- confirm that this
 * is acceptable for the tests using this helper.
 *
 * @param n_workers Number of workers, forwarded to TestFederatedImpl and the config.
 * @param fn        Callable taking no arguments.
 */
template <typename WorkerFn>
void TestFederatedGlobal(std::int32_t n_workers, WorkerFn&& fn) {
TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) {
auto config = FederatedTestConfig(n_workers, port, i);
collective::Init(config);
fn();
collective::Finalize();
});
}
} // namespace xgboost::collective

View File

@@ -1,99 +0,0 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <dmlc/omp.h>
#include <grpcpp/server_builder.h>
#include <gtest/gtest.h>
#include <xgboost/json.h>
#include <random>
#include <thread> // for thread, sleep_for
#include "../../../plugin/federated/federated_server.h"
#include "../../../src/collective/communicator-inl.h"
#include "../../../src/common/threading_utils.h"
namespace xgboost {
class ServerForTest {
std::string server_address_;
std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_;
public:
explicit ServerForTest(std::size_t world_size) {
server_thread_.reset(new std::thread([this, world_size] {
grpc::ServerBuilder builder;
xgboost::federated::FederatedService service{static_cast<std::int32_t>(world_size)};
int selected_port;
builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port);
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
server_address_ = std::string("localhost:") + std::to_string(selected_port);
server_->Wait();
}));
}
~ServerForTest() {
using namespace std::chrono_literals;
while (!server_) {
std::this_thread::sleep_for(100ms);
}
server_->Shutdown();
while (!server_thread_) {
std::this_thread::sleep_for(100ms);
}
server_thread_->join();
}
auto Address() const {
using namespace std::chrono_literals;
while (server_address_.empty()) {
std::this_thread::sleep_for(100ms);
}
return server_address_;
}
};
class BaseFederatedTest : public ::testing::Test {
protected:
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
void TearDown() override { server_.reset(nullptr); }
static int constexpr kWorldSize{2};
std::unique_ptr<ServerForTest> server_;
};
/**
 * @brief Run `function(args...)` once per rank with the federated communicator initialised.
 *
 * For every rank in [0, world_size) a configuration pointing at `server_address` is
 * built, xgboost::collective::Init is called with it, `function` is invoked, and
 * xgboost::collective::Finalize is called afterwards. The ranks are executed through
 * common::ParallelFor when OpenMP is available and on one std::thread per rank otherwise.
 *
 * @param world_size     Number of ranks to run.
 * @param server_address Address stored as `federated_server_address`.
 * @param function       Callable invoked with `args` on every rank.
 * @param args           Arguments passed to `function`; they are shared by all ranks.
 */
template <typename Function, typename... Args>
void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_address,
                                  Function&& function, Args&&... args) {
  auto run = [&](auto rank) {
    Json config{JsonObject()};
    config["xgboost_communicator"] = String("federated");
    config["federated_secure"] = false;
    config["federated_server_address"] = String(server_address);
    config["federated_world_size"] = world_size;
    config["federated_rank"] = rank;
    xgboost::collective::Init(config);

    // `run` executes once per rank, potentially concurrently. The callable and its
    // arguments are therefore used as lvalues: forwarding them here would move from the
    // same objects several times (and race on them) whenever an rvalue is passed in.
    function(args...);
    xgboost::collective::Finalize();
  };
#if defined(_OPENMP)
  common::ParallelFor(world_size, world_size, run);
#else
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < world_size; rank++) {
    threads.emplace_back(run, rank);
  }
  for (auto& thread : threads) {
    thread.join();
  }
#endif
}
} // namespace xgboost

View File

@@ -1,97 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <thrust/host_vector.h>
#include <ctime>
#include <iostream>
#include <thread>
#include "../../../plugin/federated/federated_communicator.h"
#include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/collective/device_communicator_adapter.cuh"
#include "../helpers.h"
#include "./helpers.h"
namespace xgboost::collective {
class FederatedAdapterTest : public BaseFederatedTest {};
TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
auto construct = []() { DeviceCommunicatorAdapter adapter{-1}; };
EXPECT_THROW(construct(), dmlc::Error);
}
namespace {
void VerifyAllReduceSum() {
auto const world_size = collective::GetWorldSize();
auto const device = GPUIDX;
int count = 3;
common::SetDevice(device);
thrust::device_vector<double> buffer(count, 0);
thrust::sequence(buffer.begin(), buffer.end());
collective::AllReduce<collective::Operation::kSum>(device, buffer.data().get(), count);
thrust::host_vector<double> host_buffer = buffer;
EXPECT_EQ(host_buffer.size(), count);
for (auto i = 0; i < count; i++) {
EXPECT_EQ(host_buffer[i], i * world_size);
}
}
} // anonymous namespace
TEST_F(FederatedAdapterTest, MGPUAllReduceSum) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllReduceSum);
}
namespace {
void VerifyAllGather() {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const device = GPUIDX;
common::SetDevice(device);
thrust::device_vector<double> send_buffer(1, rank);
thrust::device_vector<double> receive_buffer(world_size, 0);
collective::AllGather(device, send_buffer.data().get(), receive_buffer.data().get(),
sizeof(double));
thrust::host_vector<double> host_buffer = receive_buffer;
EXPECT_EQ(host_buffer.size(), world_size);
for (auto i = 0; i < world_size; i++) {
EXPECT_EQ(host_buffer[i], i);
}
}
} // anonymous namespace
TEST_F(FederatedAdapterTest, MGPUAllGather) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGather);
}
namespace {
void VerifyAllGatherV() {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const device = GPUIDX;
int const count = rank + 2;
common::SetDevice(device);
thrust::device_vector<char> buffer(count, 0);
thrust::sequence(buffer.begin(), buffer.end());
std::vector<std::size_t> segments(world_size);
dh::caching_device_vector<char> receive_buffer{};
collective::AllGatherV(device, buffer.data().get(), count, &segments, &receive_buffer);
EXPECT_EQ(segments[0], 2);
EXPECT_EQ(segments[1], 3);
thrust::host_vector<char> host_buffer = receive_buffer;
EXPECT_EQ(host_buffer.size(), 5);
int expected[] = {0, 1, 0, 1, 2};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(host_buffer[i], expected[i]);
}
}
} // anonymous namespace
TEST_F(FederatedAdapterTest, MGPUAllGatherV) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGatherV);
}
} // namespace xgboost::collective

View File

@@ -1,161 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include <iostream>
#include <thread>
#include "../../../plugin/federated/federated_communicator.h"
#include "helpers.h"
namespace xgboost::collective {
class FederatedCommunicatorTest : public BaseFederatedTest {
public:
static void VerifyAllgather(int rank, const std::string &server_address) {
FederatedCommunicator comm{kWorldSize, rank, server_address};
CheckAllgather(comm, rank);
}
static void VerifyAllgatherV(int rank, const std::string &server_address) {
FederatedCommunicator comm{kWorldSize, rank, server_address};
CheckAllgatherV(comm, rank);
}
static void VerifyAllreduce(int rank, const std::string &server_address) {
FederatedCommunicator comm{kWorldSize, rank, server_address};
CheckAllreduce(comm);
}
static void VerifyBroadcast(int rank, const std::string &server_address) {
FederatedCommunicator comm{kWorldSize, rank, server_address};
CheckBroadcast(comm, rank);
}
protected:
static void CheckAllgather(FederatedCommunicator &comm, int rank) {
std::string input{static_cast<char>('0' + rank)};
auto output = comm.AllGather(input);
for (auto i = 0; i < kWorldSize; i++) {
EXPECT_EQ(output[i], static_cast<char>('0' + i));
}
}
static void CheckAllgatherV(FederatedCommunicator &comm, int rank) {
std::vector<std::string_view> inputs{"Federated", " Learning!!!"};
auto output = comm.AllGatherV(inputs[rank]);
EXPECT_EQ(output, "Federated Learning!!!");
}
static void CheckAllreduce(FederatedCommunicator &comm) {
int buffer[] = {1, 2, 3, 4, 5};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
int expected[] = {2, 4, 6, 8, 10};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(buffer[i], expected[i]);
}
}
static void CheckBroadcast(FederatedCommunicator &comm, int rank) {
if (rank == 0) {
std::string buffer{"hello"};
comm.Broadcast(&buffer[0], buffer.size(), 0);
EXPECT_EQ(buffer, "hello");
} else {
std::string buffer{" "};
comm.Broadcast(&buffer[0], buffer.size(), 0);
EXPECT_EQ(buffer, "hello");
}
}
};
TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
auto construct = [] { FederatedCommunicator comm{0, 0, "localhost:0", "", "", ""}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall) {
auto construct = [] { FederatedCommunicator comm{1, -1, "localhost:0", "", "", ""}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) {
auto construct = [] { FederatedCommunicator comm{1, 1, "localhost:0", "", "", ""}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
auto construct = [] {
Json config{JsonObject()};
config["federated_server_address"] = std::string("localhost:0");
config["federated_world_size"] = std::string("1");
config["federated_rank"] = Integer(0);
FederatedCommunicator::Create(config);
};
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) {
auto construct = [] {
Json config{JsonObject()};
config["federated_server_address"] = std::string("localhost:0");
config["federated_world_size"] = 1;
config["federated_rank"] = std::string("0");
FederatedCommunicator::Create(config);
};
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) {
FederatedCommunicator comm{6, 3, "localhost:0"};
EXPECT_EQ(comm.GetWorldSize(), 6);
EXPECT_EQ(comm.GetRank(), 3);
}
TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
FederatedCommunicator comm{2, 1, "localhost:0"};
EXPECT_TRUE(comm.IsDistributed());
}
TEST_F(FederatedCommunicatorTest, Allgather) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_->Address());
}
for (auto &thread : threads) {
thread.join();
}
}
TEST_F(FederatedCommunicatorTest, AllgatherV) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgatherV, rank, server_->Address());
}
for (auto &thread : threads) {
thread.join();
}
}
TEST_F(FederatedCommunicatorTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_->Address());
}
for (auto &thread : threads) {
thread.join();
}
}
TEST_F(FederatedCommunicatorTest, Broadcast) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_->Address());
}
for (auto &thread : threads) {
thread.join();
}
}
} // namespace xgboost::collective

View File

@@ -6,16 +6,13 @@
#include <thread>
#include "../../../plugin/federated/federated_server.h"
#include "../../../src/collective/communicator-inl.h"
#include "../filesystem.h"
#include "../helpers.h"
#include "helpers.h"
#include "federated/test_worker.h"
namespace xgboost {
class FederatedDataTest : public BaseFederatedTest {};
void VerifyLoadUri() {
auto const rank = collective::GetRank();
@@ -47,7 +44,8 @@ void VerifyLoadUri() {
}
}
TEST_F(FederatedDataTest, LoadUri) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLoadUri);
// Runs VerifyLoadUri through collective::TestFederatedGlobal with a world size of two.
TEST(FederatedDataTest, LoadUri) {
static int constexpr kWorldSize{2};
collective::TestFederatedGlobal(kWorldSize, [] { VerifyLoadUri(); });
}
} // namespace xgboost

View File

@@ -1,17 +1,19 @@
/*!
* Copyright 2023 XGBoost contributors
/**
* Copyright 2023-2024, XGBoost contributors
*
* Some other tests for federated learning are in the main test suite (test_learner.cc),
* guarded by `XGBOOST_USE_FEDERATED`.
*/
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <xgboost/objective.h>
#include "../../../plugin/federated/federated_server.h"
#include "../../../src/collective/communicator-inl.h"
#include "../../../src/common/linalg_op.h"
#include "../../../src/common/linalg_op.h" // for begin, end
#include "../helpers.h"
#include "../objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator
#include "helpers.h"
#include "federated/test_worker.h"
namespace xgboost {
namespace {
@@ -36,32 +38,16 @@ auto MakeModel(std::string tree_method, std::string device, std::string objectiv
return model;
}
void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model,
std::string tree_method, std::string device, std::string objective) {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
void VerifyObjective(std::size_t rows, std::size_t cols, float expected_base_score,
Json expected_model, std::string const &tree_method, std::string device,
std::string const &objective) {
auto rank = collective::GetRank();
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
if (rank == 0) {
auto &h_upper = dmat->Info().labels_upper_bound_.HostVector();
auto &h_lower = dmat->Info().labels_lower_bound_.HostVector();
h_lower.resize(rows);
h_upper.resize(rows);
for (size_t i = 0; i < rows; ++i) {
h_lower[i] = 1;
h_upper[i] = 10;
}
if (objective.find("rank:") != std::string::npos) {
auto h_label = dmat->Info().labels.HostView();
std::size_t k = 0;
for (auto &v : h_label) {
v = k % 2 == 0;
++k;
}
}
MakeLabelForObjTest(dmat, objective);
}
std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};
std::shared_ptr<DMatrix> sliced{dmat->SliceCol(collective::GetWorldSize(), rank)};
auto model = MakeModel(tree_method, device, objective, sliced);
auto base_score = GetBaseScore(model);
@@ -71,18 +57,15 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e
} // namespace
class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string> {
std::unique_ptr<ServerForTest> server_;
static int constexpr kWorldSize{3};
protected:
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
void TearDown() override { server_.reset(nullptr); }
void Run(std::string tree_method, std::string device, std::string objective) {
static auto constexpr kRows{16};
static auto constexpr kCols{16};
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
MakeLabelForObjTest(dmat, objective);
auto &h_upper = dmat->Info().labels_upper_bound_.HostVector();
auto &h_lower = dmat->Info().labels_lower_bound_.HostVector();
@@ -103,9 +86,9 @@ class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string
auto model = MakeModel(tree_method, device, objective, dmat);
auto score = GetBaseScore(model);
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols,
score, model, tree_method, device, objective);
collective::TestFederatedGlobal(kWorldSize, [&]() {
VerifyObjective(kRows, kCols, score, model, tree_method, device, objective);
});
}
};

View File

@@ -1,243 +0,0 @@
/*!
* Copyright 2023 XGBoost contributors
*/
#include <gtest/gtest.h>
#include "../metric/test_auc.h"
#include "../metric/test_elementwise_metric.h"
#include "../metric/test_multiclass_metric.h"
#include "../metric/test_rank_metric.h"
#include "../metric/test_survival_metric.h"
#include "helpers.h"
namespace {
// Fixture for the federated metric tests; it adds nothing on top of
// xgboost::BaseFederatedTest and exists to give the tests their own suite name.
class FederatedMetricTest : public xgboost::BaseFederatedTest {};
} // anonymous namespace
// Federated variants of the metric tests. Every test passes one of the Verify* routines
// to RunWithFederatedCommunicator together with kWorldSize, the fixture server's address
// and a DataSplitMode; each routine is registered once for DataSplitMode::kRow and once
// for DataSplitMode::kCol.
namespace xgboost {
namespace metric {
// AUC.
TEST_F(FederatedMetricTest, BinaryAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, BinaryAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MultiClassAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MultiClassAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, RankingAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, RankingAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC,
DataSplitMode::kCol);
}
// PR-AUC.
TEST_F(FederatedMetricTest, PRAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, PRAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MultiClassPRAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MultiClassPRAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, RankingPRAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, RankingPRAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC,
DataSplitMode::kCol);
}
// RMSE, RMSLE, MAE, MAPE, MPHE, log loss, error, Poisson, multi-RMSE and quantile.
TEST_F(FederatedMetricTest, RMSERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, RMSEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, RMSLERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, RMSLEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MAERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MAEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MAPERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MAPEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MPHERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MPHEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, LogLossRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, LogLossColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, ErrorRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, ErrorColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, PoissonNegLogLikRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, PoissonNegLogLikColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MultiRMSERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MultiRMSEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, QuantileRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, QuantileColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile,
DataSplitMode::kCol);
}
// Multi-class error and log loss.
TEST_F(FederatedMetricTest, MultiClassErrorRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MultiClassErrorColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MultiClassLogLossRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MultiClassLogLossColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss,
DataSplitMode::kCol);
}
// Precision, NDCG, MAP and NDCG with exponential gain.
TEST_F(FederatedMetricTest, PrecisionRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, PrecisionColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, NDCGRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, NDCGColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MAPRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MAPColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, NDCGExpGainRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, NDCGExpGainColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain,
DataSplitMode::kCol);
}
} // namespace metric
} // namespace xgboost
// Federated variants of the AFT negative log-likelihood and interval regression accuracy
// tests. They are registered from namespace xgboost::common and, like the tests above,
// run each Verify* routine once with DataSplitMode::kRow and once with DataSplitMode::kCol.
namespace xgboost {
namespace common {
TEST_F(FederatedMetricTest, AFTNegLogLikRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, AFTNegLogLikColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, IntervalRegressionAccuracyRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, IntervalRegressionAccuracyColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy,
DataSplitMode::kCol);
}
} // namespace common
} // namespace xgboost

View File

@@ -1,133 +0,0 @@
/*!
* Copyright 2017-2020 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <iostream>
#include <thread>
#include "federated_client.h"
#include "helpers.h"
namespace xgboost {
class FederatedServerTest : public BaseFederatedTest {
public:
static void VerifyAllgather(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank};
CheckAllgather(client, rank);
}
static void VerifyAllgatherV(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank};
CheckAllgatherV(client, rank);
}
static void VerifyAllreduce(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank};
CheckAllreduce(client);
}
static void VerifyBroadcast(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank};
CheckBroadcast(client, rank);
}
static void VerifyMixture(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank};
for (auto i = 0; i < 10; i++) {
CheckAllgather(client, rank);
CheckAllreduce(client);
CheckBroadcast(client, rank);
}
}
protected:
static void CheckAllgather(federated::FederatedClient& client, int rank) {
int data[] = {rank};
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
auto reply = client.Allgather(send_buffer);
auto const* result = reinterpret_cast<int const*>(reply.data());
for (auto i = 0; i < kWorldSize; i++) {
EXPECT_EQ(result[i], i);
}
}
static void CheckAllgatherV(federated::FederatedClient& client, int rank) {
std::vector<std::string_view> inputs{"Hello,", " World!"};
auto reply = client.AllgatherV(inputs[rank]);
EXPECT_EQ(reply, "Hello, World!");
}
static void CheckAllreduce(federated::FederatedClient& client) {
int data[] = {1, 2, 3, 4, 5};
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
auto reply = client.Allreduce(send_buffer, federated::INT32, federated::SUM);
auto const* result = reinterpret_cast<int const*>(reply.data());
int expected[] = {2, 4, 6, 8, 10};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(result[i], expected[i]);
}
}
static void CheckBroadcast(federated::FederatedClient& client, int rank) {
std::string send_buffer{};
if (rank == 0) {
send_buffer = "hello broadcast";
}
auto reply = client.Broadcast(send_buffer, 0);
EXPECT_EQ(reply, "hello broadcast") << "rank " << rank;
}
};
TEST_F(FederatedServerTest, Allgather) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_->Address());
}
for (auto& thread : threads) {
thread.join();
}
}
TEST_F(FederatedServerTest, AllgatherV) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedServerTest::VerifyAllgatherV, rank, server_->Address());
}
for (auto& thread : threads) {
thread.join();
}
}
TEST_F(FederatedServerTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_->Address());
}
for (auto& thread : threads) {
thread.join();
}
}
TEST_F(FederatedServerTest, Broadcast) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_->Address());
}
for (auto& thread : threads) {
thread.join();
}
}
TEST_F(FederatedServerTest, Mixture) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_->Address());
}
for (auto& thread : threads) {
thread.join();
}
}
} // namespace xgboost