initial merge

This commit is contained in:
amdsc21
2023-03-25 04:31:55 +01:00
146 changed files with 6730 additions and 4082 deletions

View File

@@ -1,19 +0,0 @@
#include <chrono>
#include <thread>
#include <random>
#include <cstdint>
#include "helpers.h"
using namespace std::chrono_literals;

/*!
 * \brief Pick a pseudo-random port number from the inclusive range [low, high].
 *
 * The generator is seeded with the current wall-clock time in milliseconds.
 *
 * \param low  Smallest port that may be returned.
 * \param high Largest port that may be returned.
 * \return A value drawn uniformly from [low, high].
 */
int GenerateRandomPort(int low, int high) {
  // Wait a little before reading the clock so that back-to-back calls see
  // different millisecond timestamps (and therefore different seeds).
  std::this_thread::sleep_for(100ms);

  auto const since_epoch = std::chrono::system_clock::now().time_since_epoch();
  auto const seed = static_cast<uint64_t>(
      std::chrono::duration_cast<std::chrono::milliseconds>(since_epoch).count());

  std::mt19937_64 engine(seed);
  return std::uniform_int_distribution<int>(low, high)(engine);
}

View File

@@ -1,10 +1,69 @@
/*!
* Copyright 2022 XGBoost contributors
* Copyright 2022-2023 XGBoost contributors
*/
#pragma once
#ifndef XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
#define XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
#include <grpcpp/server_builder.h>
#include <gtest/gtest.h>
#include <xgboost/json.h>
int GenerateRandomPort(int low, int high);
#include <random>
#endif // XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
#include "../../../plugin/federated/federated_server.h"
#include "../../../src/collective/communicator-inl.h"
/*!
 * \brief Draw a port number from the inclusive range [low, high].
 *
 * Sleeps for 100 milliseconds, then seeds a 64-bit Mersenne Twister with the
 * system clock's millisecond count and samples a uniform integer.
 *
 * \param low  Lower bound (inclusive) of the port range.
 * \param high Upper bound (inclusive) of the port range.
 * \return The sampled port number.
 */
inline int GenerateRandomPort(int low, int high) {
  namespace chrono = std::chrono;

  // The artificial delay keeps consecutive calls from reading the same
  // millisecond timestamp, which would otherwise produce the same seed.
  std::this_thread::sleep_for(chrono::milliseconds{100});

  auto const now_ms =
      chrono::duration_cast<chrono::milliseconds>(chrono::system_clock::now().time_since_epoch());
  std::mt19937_64 generator(static_cast<uint64_t>(now_ms.count()));
  std::uniform_int_distribution<int> pick_port(low, high);
  return pick_port(generator);
}
/*!
 * \brief Build a "localhost:<port>" address string for a test server.
 *
 * The port comes from GenerateRandomPort(50000, 60000).
 *
 * \return The address in "localhost:<port>" form.
 */
inline std::string GetServerAddress() {
  return std::string("localhost:") + std::to_string(GenerateRandomPort(50000, 60000));
}
namespace xgboost {
class BaseFederatedTest : public ::testing::Test {
protected:
void SetUp() override {
server_address_ = GetServerAddress();
server_thread_.reset(new std::thread([this] {
grpc::ServerBuilder builder;
xgboost::federated::FederatedService service{kWorldSize};
builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
server_->Wait();
}));
}
void TearDown() override {
server_->Shutdown();
server_thread_->join();
}
void InitCommunicator(int rank) {
Json config{JsonObject()};
config["xgboost_communicator"] = String("federated");
config["federated_server_address"] = String(server_address_);
config["federated_world_size"] = kWorldSize;
config["federated_rank"] = rank;
xgboost::collective::Init(config);
}
static int const kWorldSize{3};
std::string server_address_;
std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_;
};
} // namespace xgboost

View File

@@ -1,56 +1,20 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <grpcpp/server_builder.h>
#include <gtest/gtest.h>
#include <thrust/host_vector.h>
#include <ctime>
#include <iostream>
#include <thread>
#include <ctime>
#include "./helpers.h"
#include "../../../plugin/federated/federated_communicator.h"
#include "../../../plugin/federated/federated_server.h"
#include "../../../src/collective/device_communicator_adapter.cuh"
#include "./helpers.h"
namespace {
namespace xgboost::collective {
std::string GetServerAddress() {
int port = GenerateRandomPort(50000, 60000);
std::string address = std::string("localhost:") + std::to_string(port);
return address;
}
} // anonymous namespace
namespace xgboost {
namespace collective {
class FederatedAdapterTest : public ::testing::Test {
protected:
void SetUp() override {
server_address_ = GetServerAddress();
server_thread_.reset(new std::thread([this] {
grpc::ServerBuilder builder;
federated::FederatedService service{kWorldSize};
builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
server_->Wait();
}));
}
void TearDown() override {
server_->Shutdown();
server_thread_->join();
}
static int const kWorldSize{2};
std::string server_address_;
std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_;
};
class FederatedAdapterTest : public BaseFederatedTest {};
TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
auto construct = []() { DeviceCommunicatorAdapter adapter{-1, nullptr}; };
@@ -65,20 +29,20 @@ TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) {
TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread([rank, server_address=server_address_] {
threads.emplace_back([rank, server_address = server_address_] {
FederatedCommunicator comm{kWorldSize, rank, server_address};
// Assign device 0 to all workers, since we run gtest in a single-GPU machine
DeviceCommunicatorAdapter adapter{0, &comm};
int const count = 3;
int count = 3;
thrust::device_vector<double> buffer(count, 0);
thrust::sequence(buffer.begin(), buffer.end());
adapter.AllReduceSum(buffer.data().get(), count);
thrust::host_vector<double> host_buffer = buffer;
EXPECT_EQ(host_buffer.size(), count);
for (auto i = 0; i < count; i++) {
EXPECT_EQ(host_buffer[i], i * 2);
EXPECT_EQ(host_buffer[i], i * kWorldSize);
}
}));
});
}
for (auto& thread : threads) {
thread.join();
@@ -88,7 +52,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread([rank, server_address=server_address_] {
threads.emplace_back([rank, server_address = server_address_] {
FederatedCommunicator comm{kWorldSize, rank, server_address};
// Assign device 0 to all workers, since we run gtest in a single-GPU machine
DeviceCommunicatorAdapter adapter{0, &comm};
@@ -104,17 +68,16 @@ TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
EXPECT_EQ(segments[0], 2);
EXPECT_EQ(segments[1], 3);
thrust::host_vector<char> host_buffer = receive_buffer;
EXPECT_EQ(host_buffer.size(), 5);
int expected[] = {0, 1, 0, 1, 2};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(host_buffer.size(), 9);
int expected[] = {0, 1, 0, 1, 2, 0, 1, 2, 3};
for (auto i = 0; i < 9; i++) {
EXPECT_EQ(host_buffer[i], expected[i]);
}
}));
});
}
for (auto& thread : threads) {
thread.join();
}
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@@ -2,65 +2,34 @@
* Copyright 2022 XGBoost contributors
*/
#include <dmlc/parameter.h>
#include <grpcpp/server_builder.h>
#include <gtest/gtest.h>
#include <iostream>
#include <thread>
#include <ctime>
#include "helpers.h"
#include "../../../plugin/federated/federated_communicator.h"
#include "../../../plugin/federated/federated_server.h"
#include "helpers.h"
namespace {
namespace xgboost::collective {
std::string GetServerAddress() {
int port = GenerateRandomPort(50000, 60000);
std::string address = std::string("localhost:") + std::to_string(port);
return address;
}
} // anonymous namespace
namespace xgboost {
namespace collective {
class FederatedCommunicatorTest : public ::testing::Test {
class FederatedCommunicatorTest : public BaseFederatedTest {
public:
static void VerifyAllgather(int rank, const std::string& server_address) {
static void VerifyAllgather(int rank, const std::string &server_address) {
FederatedCommunicator comm{kWorldSize, rank, server_address};
CheckAllgather(comm, rank);
}
static void VerifyAllreduce(int rank, const std::string& server_address) {
static void VerifyAllreduce(int rank, const std::string &server_address) {
FederatedCommunicator comm{kWorldSize, rank, server_address};
CheckAllreduce(comm);
}
static void VerifyBroadcast(int rank, const std::string& server_address) {
static void VerifyBroadcast(int rank, const std::string &server_address) {
FederatedCommunicator comm{kWorldSize, rank, server_address};
CheckBroadcast(comm, rank);
}
protected:
void SetUp() override {
server_address_ = GetServerAddress();
server_thread_.reset(new std::thread([this] {
grpc::ServerBuilder builder;
federated::FederatedService service{kWorldSize};
builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
server_->Wait();
}));
}
void TearDown() override {
server_->Shutdown();
server_thread_->join();
}
static void CheckAllgather(FederatedCommunicator &comm, int rank) {
int buffer[kWorldSize] = {0, 0, 0};
buffer[rank] = rank;
@@ -90,11 +59,6 @@ class FederatedCommunicatorTest : public ::testing::Test {
EXPECT_EQ(buffer, "hello");
}
}
static int const kWorldSize{3};
std::string server_address_;
std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_;
};
TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
@@ -161,8 +125,7 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
TEST_F(FederatedCommunicatorTest, Allgather) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(
std::thread(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_));
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_);
}
for (auto &thread : threads) {
thread.join();
@@ -172,8 +135,7 @@ TEST_F(FederatedCommunicatorTest, Allgather) {
TEST_F(FederatedCommunicatorTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(
std::thread(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_));
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_);
}
for (auto &thread : threads) {
thread.join();
@@ -183,12 +145,10 @@ TEST_F(FederatedCommunicatorTest, Allreduce) {
TEST_F(FederatedCommunicatorTest, Broadcast) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(
std::thread(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_));
threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_);
}
for (auto &thread : threads) {
thread.join();
}
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@@ -0,0 +1,65 @@
/*!
* Copyright 2023 XGBoost contributors
*/
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <fstream>
#include <iostream>
#include <thread>
#include "../../../plugin/federated/federated_server.h"
#include "../../../src/collective/communicator-inl.h"
#include "../filesystem.h"
#include "../helpers.h"
#include "helpers.h"
namespace xgboost {
/*!
 * \brief Fixture for testing column-split data loading with the federated communicator.
 */
class FederatedDataTest : public BaseFederatedTest {
 public:
  /*!
   * \brief Per-rank worker: write a CSV, load it with column split, and check the result.
   *
   * Each rank writes a CSV with 16 rows and (8 + rank) columns, loads it via
   * DMatrix::Load with DataSplitMode::kCol, and checks the reported shape and
   * the feature index of every entry.
   *
   * \param rank Rank of the calling worker; also used to index the offsets table.
   */
  void VerifyLoadUri(int rank) {
    InitCommunicator(rank);

    size_t constexpr kRows{16};
    size_t const kCols = 8 + rank;

    dmlc::TemporaryDirectory tmpdir;
    std::string csv_path = tmpdir.path + "/small" + std::to_string(rank) + ".csv";
    CreateTestCSV(csv_path, kRows, kCols);

    std::string uri = csv_path + "?format=csv";
    std::unique_ptr<DMatrix> dmat{DMatrix::Load(uri, false, DataSplitMode::kCol)};

    // Ranks 0, 1, 2 contribute 8, 9 and 10 columns, i.e. 8 * kWorldSize + 3 in total.
    ASSERT_EQ(dmat->Info().num_col_, 8 * kWorldSize + 3);
    ASSERT_EQ(dmat->Info().num_row_, kRows);

    // Expected first feature index for each rank (running sum of 8, 9, 10).
    int const offsets[] = {0, 8, 17};
    int const offset = offsets[rank];

    for (auto const& page : dmat->GetBatches<SparsePage>()) {
      auto entries = page.GetView().data;
      auto cursor = 0;
      for (auto row = 0; row < kRows; ++row) {
        for (auto col = 0; col < kCols; ++col, ++cursor) {
          EXPECT_EQ(entries[cursor].index, col + offset);
        }
      }
    }

    xgboost::collective::Finalize();
  }
};
/*! \brief Run VerifyLoadUri once per rank, each on its own thread, then join them all. */
TEST_F(FederatedDataTest, LoadUri) {
  std::vector<std::thread> workers;
  for (auto rank = 0; rank < kWorldSize; ++rank) {
    // Each worker calls the fixture method on this test object with its own rank.
    workers.emplace_back([this, rank] { VerifyLoadUri(rank); });
  }
  for (auto& worker : workers) {
    worker.join();
  }
}
} // namespace xgboost

View File

@@ -1,30 +1,17 @@
/*!
* Copyright 2017-2020 XGBoost contributors
*/
#include <grpcpp/server_builder.h>
#include <gtest/gtest.h>
#include <ctime>
#include <iostream>
#include <thread>
#include "federated_client.h"
#include "federated_server.h"
#include "helpers.h"
namespace {
std::string GetServerAddress() {
int port = GenerateRandomPort(50000, 60000);
std::string address = std::string("localhost:") + std::to_string(port);
return address;
}
} // anonymous namespace
namespace xgboost {
class FederatedServerTest : public ::testing::Test {
class FederatedServerTest : public BaseFederatedTest {
public:
static void VerifyAllgather(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank};
@@ -51,23 +38,6 @@ class FederatedServerTest : public ::testing::Test {
}
protected:
void SetUp() override {
server_address_ = GetServerAddress();
server_thread_.reset(new std::thread([this] {
grpc::ServerBuilder builder;
federated::FederatedService service{kWorldSize};
builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
server_->Wait();
}));
}
void TearDown() override {
server_->Shutdown();
server_thread_->join();
}
static void CheckAllgather(federated::FederatedClient& client, int rank) {
int data[kWorldSize] = {0, 0, 0};
data[rank] = rank;
@@ -98,17 +68,12 @@ class FederatedServerTest : public ::testing::Test {
auto reply = client.Broadcast(send_buffer, 0);
EXPECT_EQ(reply, "hello broadcast") << "rank " << rank;
}
static int const kWorldSize{3};
std::string server_address_;
std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_;
};
TEST_F(FederatedServerTest, Allgather) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank, server_address_));
threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_address_);
}
for (auto& thread : threads) {
thread.join();
@@ -118,7 +83,7 @@ TEST_F(FederatedServerTest, Allgather) {
TEST_F(FederatedServerTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllreduce, rank, server_address_));
threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_address_);
}
for (auto& thread : threads) {
thread.join();
@@ -128,7 +93,7 @@ TEST_F(FederatedServerTest, Allreduce) {
TEST_F(FederatedServerTest, Broadcast) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&FederatedServerTest::VerifyBroadcast, rank, server_address_));
threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_address_);
}
for (auto& thread : threads) {
thread.join();
@@ -138,7 +103,7 @@ TEST_F(FederatedServerTest, Broadcast) {
TEST_F(FederatedServerTest, Mixture) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&FederatedServerTest::VerifyMixture, rank, server_address_));
threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_address_);
}
for (auto& thread : threads) {
thread.join();