Add an in-memory collective communicator (#8494)

This commit is contained in:
Rong Ou
2022-11-30 08:24:12 -08:00
committed by GitHub
parent 157e98edf7
commit a8255ea678
15 changed files with 577 additions and 277 deletions

View File

@@ -6,7 +6,6 @@ syntax = "proto3";
package xgboost.federated;
service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
}
@@ -28,17 +27,6 @@ enum ReduceOperation {
SUM = 2;
}
message AllgatherRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
}
message AllgatherReply {
bytes receive_buffer = 1;
}
message AllreduceRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;

View File

@@ -46,25 +46,6 @@ class FederatedClient {
}()},
rank_{rank} {}
std::string Allgather(std::string const &send_buffer) {
AllgatherRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
AllgatherReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub_->Allgather(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Allgather RPC failed");
}
}
std::string Allreduce(std::string const &send_buffer, DataType data_type,
ReduceOperation reduce_operation) {
AllreduceRequest request;

View File

@@ -12,40 +12,6 @@
namespace xgboost {
namespace collective {
/** @brief Get the size of the data type. */
inline std::size_t GetTypeSize(DataType data_type) {
std::size_t size{0};
switch (data_type) {
case DataType::kInt8:
size = sizeof(std::int8_t);
break;
case DataType::kUInt8:
size = sizeof(std::uint8_t);
break;
case DataType::kInt32:
size = sizeof(std::int32_t);
break;
case DataType::kUInt32:
size = sizeof(std::uint32_t);
break;
case DataType::kInt64:
size = sizeof(std::int64_t);
break;
case DataType::kUInt64:
size = sizeof(std::uint64_t);
break;
case DataType::kFloat:
size = sizeof(float);
break;
case DataType::kDouble:
size = sizeof(double);
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return size;
}
/**
* @brief A Federated Learning communicator class that handles collective communication.
*/

View File

@@ -7,7 +7,6 @@
#include <grpcpp/server_builder.h>
#include <xgboost/logging.h>
#include <fstream>
#include <sstream>
#include "../../src/common/io.h"
@@ -15,184 +14,20 @@
namespace xgboost {
namespace federated {
class AllgatherFunctor {
public:
std::string const name{"Allgather"};
explicit AllgatherFunctor(int const world_size) : world_size_{world_size} {}
void operator()(AllgatherRequest const* request, std::string& buffer) const {
auto const rank = request->rank();
auto const& send_buffer = request->send_buffer();
auto const send_size = send_buffer.size();
// Resize the buffer if this is the first request.
if (buffer.size() != send_size * world_size_) {
buffer.resize(send_size * world_size_);
}
// Splice the send_buffer into the common buffer.
buffer.replace(rank * send_size, send_size, send_buffer);
}
private:
int const world_size_;
};
class AllreduceFunctor {
public:
std::string const name{"Allreduce"};
void operator()(AllreduceRequest const* request, std::string& buffer) const {
if (buffer.empty()) {
// Copy the send_buffer if this is the first request.
buffer = request->send_buffer();
} else {
// Apply the reduce_operation to the send_buffer and the common buffer.
Accumulate(buffer, request->send_buffer(), request->data_type(), request->reduce_operation());
}
}
private:
template <class T>
void Accumulate(T* buffer, T const* input, std::size_t n,
ReduceOperation reduce_operation) const {
switch (reduce_operation) {
case ReduceOperation::MAX:
std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::max(a, b); });
break;
case ReduceOperation::MIN:
std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::min(a, b); });
break;
case ReduceOperation::SUM:
std::transform(buffer, buffer + n, input, buffer, std::plus<T>());
break;
default:
throw std::invalid_argument("Invalid reduce operation");
}
}
void Accumulate(std::string& buffer, std::string const& input, DataType data_type,
ReduceOperation reduce_operation) const {
switch (data_type) {
case DataType::INT8:
Accumulate(reinterpret_cast<std::int8_t*>(&buffer[0]),
reinterpret_cast<std::int8_t const*>(input.data()), buffer.size(),
reduce_operation);
break;
case DataType::UINT8:
Accumulate(reinterpret_cast<std::uint8_t*>(&buffer[0]),
reinterpret_cast<std::uint8_t const*>(input.data()), buffer.size(),
reduce_operation);
break;
case DataType::INT32:
Accumulate(reinterpret_cast<std::int32_t*>(&buffer[0]),
reinterpret_cast<std::int32_t const*>(input.data()),
buffer.size() / sizeof(std::uint32_t), reduce_operation);
break;
case DataType::UINT32:
Accumulate(reinterpret_cast<std::uint32_t*>(&buffer[0]),
reinterpret_cast<std::uint32_t const*>(input.data()),
buffer.size() / sizeof(std::uint32_t), reduce_operation);
break;
case DataType::INT64:
Accumulate(reinterpret_cast<std::int64_t*>(&buffer[0]),
reinterpret_cast<std::int64_t const*>(input.data()),
buffer.size() / sizeof(std::int64_t), reduce_operation);
break;
case DataType::UINT64:
Accumulate(reinterpret_cast<std::uint64_t*>(&buffer[0]),
reinterpret_cast<std::uint64_t const*>(input.data()),
buffer.size() / sizeof(std::uint64_t), reduce_operation);
break;
case DataType::FLOAT:
Accumulate(reinterpret_cast<float*>(&buffer[0]),
reinterpret_cast<float const*>(input.data()), buffer.size() / sizeof(float),
reduce_operation);
break;
case DataType::DOUBLE:
Accumulate(reinterpret_cast<double*>(&buffer[0]),
reinterpret_cast<double const*>(input.data()), buffer.size() / sizeof(double),
reduce_operation);
break;
default:
throw std::invalid_argument("Invalid data type");
}
}
};
class BroadcastFunctor {
public:
std::string const name{"Broadcast"};
void operator()(BroadcastRequest const* request, std::string& buffer) const {
if (request->rank() == request->root()) {
// Copy the send_buffer if this is the root.
buffer = request->send_buffer();
}
}
};
grpc::Status FederatedService::Allgather(grpc::ServerContext* context,
AllgatherRequest const* request, AllgatherReply* reply) {
return Handle(request, reply, AllgatherFunctor{world_size_});
}
grpc::Status FederatedService::Allreduce(grpc::ServerContext* context,
AllreduceRequest const* request, AllreduceReply* reply) {
return Handle(request, reply, AllreduceFunctor{});
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
static_cast<xgboost::collective::DataType>(request->data_type()),
static_cast<xgboost::collective::Operation>(request->reduce_operation()));
return grpc::Status::OK;
}
grpc::Status FederatedService::Broadcast(grpc::ServerContext* context,
BroadcastRequest const* request, BroadcastReply* reply) {
return Handle(request, reply, BroadcastFunctor{});
}
template <class Request, class Reply, class RequestFunctor>
grpc::Status FederatedService::Handle(Request const* request, Reply* reply,
RequestFunctor const& functor) {
// Pass through if there is only 1 client.
if (world_size_ == 1) {
reply->set_receive_buffer(request->send_buffer());
return grpc::Status::OK;
}
std::unique_lock<std::mutex> lock(mutex_);
auto const sequence_number = request->sequence_number();
auto const rank = request->rank();
LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number";
cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; });
LOG(INFO) << functor.name << " rank " << rank << ": handling request";
functor(request, buffer_);
received_++;
if (received_ == world_size_) {
LOG(INFO) << functor.name << " rank " << rank << ": all requests received";
reply->set_receive_buffer(buffer_);
sent_++;
lock.unlock();
cv_.notify_all();
return grpc::Status::OK;
}
LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients";
cv_.wait(lock, [this] { return received_ == world_size_; });
LOG(INFO) << functor.name << " rank " << rank << ": sending reply";
reply->set_receive_buffer(buffer_);
sent_++;
if (sent_ == world_size_) {
LOG(INFO) << functor.name << " rank " << rank << ": all replies sent";
sent_ = 0;
received_ = 0;
buffer_.clear();
sequence_number_++;
lock.unlock();
cv_.notify_all();
}
handler_.Broadcast(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
request->root());
return grpc::Status::OK;
}

View File

@@ -5,18 +5,14 @@
#include <federated.grpc.pb.h>
#include <condition_variable>
#include <mutex>
#include "../../src/collective/in_memory_handler.h"
namespace xgboost {
namespace federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(int const world_size) : world_size_{world_size} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
explicit FederatedService(int const world_size) : handler_{world_size} {}
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
AllreduceReply* reply) override;
@@ -25,16 +21,7 @@ class FederatedService final : public Federated::Service {
BroadcastReply* reply) override;
private:
template <class Request, class Reply, class RequestFunctor>
grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor);
int const world_size_;
int received_{};
int sent_{};
std::string buffer_{};
uint64_t sequence_number_{};
mutable std::mutex mutex_;
mutable std::condition_variable cv_;
xgboost::collective::InMemoryHandler handler_;
};
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,