Add Allgather to collective communicator (#8765)

* Add Allgather to collective communicator
This commit is contained in:
Rong Ou 2023-02-08 19:31:22 -08:00 committed by GitHub
parent 48cefa012e
commit cbf98cb9c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 187 additions and 4 deletions

View File

@ -6,6 +6,7 @@ syntax = "proto3";
package xgboost.federated; package xgboost.federated;
service Federated { service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {} rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
} }
@ -30,6 +31,17 @@ enum ReduceOperation {
BITWISE_XOR = 5; BITWISE_XOR = 5;
} }
message AllgatherRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
}
message AllgatherReply {
bytes receive_buffer = 1;
}
message AllreduceRequest { message AllreduceRequest {
// An incrementing counter that is unique to each round to operations. // An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1; uint64 sequence_number = 1;

View File

@ -46,6 +46,25 @@ class FederatedClient {
}()}, }()},
rank_{rank} {} rank_{rank} {}
std::string Allgather(std::string const &send_buffer) {
AllgatherRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
AllgatherReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub_->Allgather(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Allgather RPC failed");
}
}
std::string Allreduce(std::string const &send_buffer, DataType data_type, std::string Allreduce(std::string const &send_buffer, DataType data_type,
ReduceOperation reduce_operation) { ReduceOperation reduce_operation) {
AllreduceRequest request; AllreduceRequest request;

View File

@ -126,6 +126,17 @@ class FederatedCommunicator : public Communicator {
*/ */
bool IsFederated() const override { return true; } bool IsFederated() const override { return true; }
/**
* \brief Perform in-place allgather.
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param size Number of bytes to be gathered.
*/
void AllGather(void *send_receive_buffer, std::size_t size) override {
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
auto const received = client_->Allgather(send_buffer);
received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
}
/** /**
* \brief Perform in-place allreduce. * \brief Perform in-place allreduce.
* \param send_receive_buffer Buffer for both sending and receiving data. * \param send_receive_buffer Buffer for both sending and receiving data.

View File

@ -14,6 +14,13 @@
namespace xgboost { namespace xgboost {
namespace federated { namespace federated {
grpc::Status FederatedService::Allgather(grpc::ServerContext* context,
AllgatherRequest const* request, AllgatherReply* reply) {
handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
return grpc::Status::OK;
}
grpc::Status FederatedService::Allreduce(grpc::ServerContext* context, grpc::Status FederatedService::Allreduce(grpc::ServerContext* context,
AllreduceRequest const* request, AllreduceReply* reply) { AllreduceRequest const* request, AllreduceReply* reply) {
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(), handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),

View File

@ -14,6 +14,9 @@ class FederatedService final : public Federated::Service {
public: public:
explicit FederatedService(int const world_size) : handler_{world_size} {} explicit FederatedService(int const world_size) : handler_{world_size} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
AllreduceReply* reply) override; AllreduceReply* reply) override;

View File

@ -122,6 +122,17 @@ class Communicator {
/** @brief Whether the communicator is running in federated mode. */ /** @brief Whether the communicator is running in federated mode. */
virtual bool IsFederated() const = 0; virtual bool IsFederated() const = 0;
/**
* @brief Gathers data from all processes and distributes it to all processes.
*
* This assumes all ranks have the same size, and input data has been sliced into the
* corresponding position.
*
* @param send_receive_buffer Buffer storing the data.
* @param size Size of the data in bytes.
*/
virtual void AllGather(void *send_receive_buffer, std::size_t size) = 0;
/** /**
* @brief Combines values from all processes and distributes the result back to all processes. * @brief Combines values from all processes and distributes the result back to all processes.
* *

View File

@ -60,6 +60,13 @@ class InMemoryCommunicator : public Communicator {
bool IsDistributed() const override { return true; } bool IsDistributed() const override { return true; }
bool IsFederated() const override { return false; } bool IsFederated() const override { return false; }
void AllGather(void* in_out, std::size_t size) override {
std::string output;
handler_.Allgather(static_cast<const char*>(in_out), size, &output, sequence_number_++,
GetRank());
output.copy(static_cast<char*>(in_out), size);
}
void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override { void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {
auto const bytes = size * GetTypeSize(data_type); auto const bytes = size * GetTypeSize(data_type);
std::string output; std::string output;

View File

@ -9,6 +9,32 @@
namespace xgboost { namespace xgboost {
namespace collective { namespace collective {
/**
* @brief Functor for allgather.
*/
class AllgatherFunctor {
public:
std::string const name{"Allgather"};
AllgatherFunctor(int world_size, int rank) : world_size_{world_size}, rank_{rank} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (buffer->empty()) {
// Copy the input if this is the first request.
buffer->assign(input, bytes);
} else {
// Splice the input into the common buffer.
auto const per_rank = bytes / world_size_;
auto const index = rank_ * per_rank;
buffer->replace(index, per_rank, input + index, per_rank);
}
}
private:
int world_size_;
int rank_;
};
/** /**
* @brief Functor for allreduce. * @brief Functor for allreduce.
*/ */
@ -17,7 +43,7 @@ class AllreduceFunctor {
std::string const name{"Allreduce"}; std::string const name{"Allreduce"};
AllreduceFunctor(DataType dataType, Operation operation) AllreduceFunctor(DataType dataType, Operation operation)
: data_type_(dataType), operation_(operation) {} : data_type_{dataType}, operation_{operation} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const { void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (buffer->empty()) { if (buffer->empty()) {
@ -128,7 +154,7 @@ class BroadcastFunctor {
public: public:
std::string const name{"Broadcast"}; std::string const name{"Broadcast"};
BroadcastFunctor(int rank, int root) : rank_(rank), root_(root) {} BroadcastFunctor(int rank, int root) : rank_{rank}, root_{root} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const { void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (rank_ == root_) { if (rank_ == root_) {
@ -167,6 +193,11 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, int) {
cv_.notify_all(); cv_.notify_all();
} }
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
}
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output, void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank, DataType data_type, std::size_t sequence_number, int rank, DataType data_type,
Operation op) { Operation op) {

View File

@ -53,6 +53,17 @@ class InMemoryHandler {
*/ */
void Shutdown(uint64_t sequence_number, int rank); void Shutdown(uint64_t sequence_number, int rank);
/**
* @brief Perform allgather.
* @param input The input buffer.
* @param bytes Number of bytes in the input buffer.
* @param output The output buffer.
* @param sequence_number Call sequence number.
* @param rank Index of the worker.
*/
void Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank);
/** /**
* @brief Perform allreduce. * @brief Perform allreduce.
* @param input The input buffer. * @param input The input buffer.

View File

@ -17,6 +17,7 @@ class NoOpCommunicator : public Communicator {
NoOpCommunicator() : Communicator(1, 0) {} NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; } bool IsDistributed() const override { return false; }
bool IsFederated() const override { return false; } bool IsFederated() const override { return false; }
void AllGather(void *, std::size_t) override {}
void AllReduce(void *, std::size_t, DataType, Operation) override {} void AllReduce(void *, std::size_t, DataType, Operation) override {}
void Broadcast(void *, std::size_t, int) override {} void Broadcast(void *, std::size_t, int) override {}
std::string GetProcessorName() override { return ""; } std::string GetProcessorName() override { return ""; }

View File

@ -55,6 +55,12 @@ class RabitCommunicator : public Communicator {
bool IsFederated() const override { return false; } bool IsFederated() const override { return false; }
void AllGather(void *send_receive_buffer, std::size_t size) override {
auto const per_rank = size / GetWorldSize();
auto const index = per_rank * GetRank();
rabit::Allgather(static_cast<char *>(send_receive_buffer), size, index, per_rank, per_rank);
}
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override { Operation op) override {
switch (data_type) { switch (data_type) {

View File

@ -24,6 +24,16 @@ class InMemoryCommunicatorTest : public ::testing::Test {
} }
} }
static void Allgather(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
char buffer[kWorldSize] = {'a', 'b', 'c'};
buffer[rank] = '0' + rank;
comm.AllGather(buffer, kWorldSize);
for (auto i = 0; i < kWorldSize; i++) {
EXPECT_EQ(buffer[i], '0' + i);
}
}
static void AllreduceMax(int rank) { static void AllreduceMax(int rank) {
InMemoryCommunicator comm{kWorldSize, rank}; InMemoryCommunicator comm{kWorldSize, rank};
int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank}; int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
@ -147,6 +157,8 @@ TEST(InMemoryCommunicatorSimpleTest, IsDistributed) {
EXPECT_TRUE(comm.IsDistributed()); EXPECT_TRUE(comm.IsDistributed());
} }
TEST_F(InMemoryCommunicatorTest, Allgather) { Verify(&Allgather); }
TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); } TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); }
TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); } TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); }

View File

@ -28,6 +28,11 @@ namespace collective {
class FederatedCommunicatorTest : public ::testing::Test { class FederatedCommunicatorTest : public ::testing::Test {
public: public:
static void VerifyAllgather(int rank, const std::string& server_address) {
FederatedCommunicator comm{kWorldSize, rank, server_address};
CheckAllgather(comm, rank);
}
static void VerifyAllreduce(int rank, const std::string& server_address) { static void VerifyAllreduce(int rank, const std::string& server_address) {
FederatedCommunicator comm{kWorldSize, rank, server_address}; FederatedCommunicator comm{kWorldSize, rank, server_address};
CheckAllreduce(comm); CheckAllreduce(comm);
@ -56,6 +61,15 @@ class FederatedCommunicatorTest : public ::testing::Test {
server_thread_->join(); server_thread_->join();
} }
static void CheckAllgather(FederatedCommunicator &comm, int rank) {
int buffer[kWorldSize] = {0, 0, 0};
buffer[rank] = rank;
comm.AllGather(buffer, sizeof(buffer));
for (auto i = 0; i < kWorldSize; i++) {
EXPECT_EQ(buffer[i], i);
}
}
static void CheckAllreduce(FederatedCommunicator &comm) { static void CheckAllreduce(FederatedCommunicator &comm) {
int buffer[] = {1, 2, 3, 4, 5}; int buffer[] = {1, 2, 3, 4, 5};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
@ -144,6 +158,17 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
EXPECT_TRUE(comm.IsDistributed()); EXPECT_TRUE(comm.IsDistributed());
} }
TEST_F(FederatedCommunicatorTest, Allgather) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(
std::thread(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_));
}
for (auto &thread : threads) {
thread.join();
}
}
TEST_F(FederatedCommunicatorTest, Allreduce) { TEST_F(FederatedCommunicatorTest, Allreduce) {
std::vector<std::thread> threads; std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) { for (auto rank = 0; rank < kWorldSize; rank++) {

View File

@ -4,13 +4,13 @@
#include <grpcpp/server_builder.h> #include <grpcpp/server_builder.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <ctime>
#include <iostream> #include <iostream>
#include <thread> #include <thread>
#include <ctime>
#include "helpers.h"
#include "federated_client.h" #include "federated_client.h"
#include "federated_server.h" #include "federated_server.h"
#include "helpers.h"
namespace { namespace {
@ -26,6 +26,11 @@ namespace xgboost {
class FederatedServerTest : public ::testing::Test { class FederatedServerTest : public ::testing::Test {
public: public:
static void VerifyAllgather(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank};
CheckAllgather(client, rank);
}
static void VerifyAllreduce(int rank, const std::string& server_address) { static void VerifyAllreduce(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank}; federated::FederatedClient client{server_address, rank};
CheckAllreduce(client); CheckAllreduce(client);
@ -39,6 +44,7 @@ class FederatedServerTest : public ::testing::Test {
static void VerifyMixture(int rank, const std::string& server_address) { static void VerifyMixture(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank}; federated::FederatedClient client{server_address, rank};
for (auto i = 0; i < 10; i++) { for (auto i = 0; i < 10; i++) {
CheckAllgather(client, rank);
CheckAllreduce(client); CheckAllreduce(client);
CheckBroadcast(client, rank); CheckBroadcast(client, rank);
} }
@ -62,6 +68,17 @@ class FederatedServerTest : public ::testing::Test {
server_thread_->join(); server_thread_->join();
} }
static void CheckAllgather(federated::FederatedClient& client, int rank) {
int data[kWorldSize] = {0, 0, 0};
data[rank] = rank;
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
auto reply = client.Allgather(send_buffer);
auto const* result = reinterpret_cast<int const*>(reply.data());
for (auto i = 0; i < kWorldSize; i++) {
EXPECT_EQ(result[i], i);
}
}
static void CheckAllreduce(federated::FederatedClient& client) { static void CheckAllreduce(federated::FederatedClient& client) {
int data[] = {1, 2, 3, 4, 5}; int data[] = {1, 2, 3, 4, 5};
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data)); std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
@ -88,6 +105,16 @@ class FederatedServerTest : public ::testing::Test {
std::unique_ptr<grpc::Server> server_; std::unique_ptr<grpc::Server> server_;
}; };
TEST_F(FederatedServerTest, Allgather) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank, server_address_));
}
for (auto& thread : threads) {
thread.join();
}
}
TEST_F(FederatedServerTest, Allreduce) { TEST_F(FederatedServerTest, Allreduce) {
std::vector<std::thread> threads; std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) { for (auto rank = 0; rank < kWorldSize; rank++) {