Add Allgather to collective communicator (#8765)

* Add Allgather to collective communicator
This commit is contained in:
Rong Ou
2023-02-08 19:31:22 -08:00
committed by GitHub
parent 48cefa012e
commit cbf98cb9c6
14 changed files with 187 additions and 4 deletions

View File

@@ -6,6 +6,7 @@ syntax = "proto3";
package xgboost.federated;
// gRPC service declaring the Allgather, Allreduce and Broadcast RPCs.
// Each RPC takes its own request message and returns the matching reply message.
service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
}
@@ -30,6 +31,17 @@ enum ReduceOperation {
BITWISE_XOR = 5;
}
// Request message for the Allgather RPC.
message AllgatherRequest {
// An incrementing counter that is unique to each round of operations.
uint64 sequence_number = 1;
// Rank of the caller issuing this request.
int32 rank = 2;
// Raw bytes supplied by the caller for this request.
bytes send_buffer = 3;
}
// Reply message for the Allgather RPC.
message AllgatherReply {
// Raw bytes handed back to the caller.
bytes receive_buffer = 1;
}
message AllreduceRequest {
// An incrementing counter that is unique to each round of operations.
uint64 sequence_number = 1;

View File

@@ -46,6 +46,25 @@ class FederatedClient {
}()},
rank_{rank} {}
std::string Allgather(std::string const &send_buffer) {
AllgatherRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
AllgatherReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub_->Allgather(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Allgather RPC failed");
}
}
std::string Allreduce(std::string const &send_buffer, DataType data_type,
ReduceOperation reduce_operation) {
AllreduceRequest request;

View File

@@ -126,6 +126,17 @@ class FederatedCommunicator : public Communicator {
*/
bool IsFederated() const override { return true; }
/**
 * \brief Perform in-place allgather.
 *
 * The first \p size bytes of the buffer are sent through the client's Allgather
 * call, and at most \p size bytes of the reply are written back over the same
 * buffer.
 *
 * \param send_receive_buffer Buffer for both sending and receiving data.
 * \param size Number of bytes to be gathered.
 */
void AllGather(void *send_receive_buffer, std::size_t size) override {
  auto *const bytes = static_cast<char *>(send_receive_buffer);
  std::string const outgoing(bytes, size);
  auto const gathered = client_->Allgather(outgoing);
  // std::string::copy writes min(size, gathered.size()) characters, so it never
  // reads past the end of the reply.
  gathered.copy(bytes, size);
}
/**
* \brief Perform in-place allreduce.
* \param send_receive_buffer Buffer for both sending and receiving data.

View File

@@ -14,6 +14,13 @@
namespace xgboost {
namespace federated {
/**
 * \brief Server-side entry point for the Allgather RPC.
 *
 * Forwards the request's send buffer, sequence number and rank to the handler,
 * which writes its result into the reply's receive buffer.
 *
 * \param context Server context for the call (not used here).
 * \param request Incoming Allgather request.
 * \param reply Reply whose receive_buffer is passed to the handler as output.
 * \return Always grpc::Status::OK.
 */
grpc::Status FederatedService::Allgather(grpc::ServerContext* context,
                                         AllgatherRequest const* request,
                                         AllgatherReply* reply) {
  auto const& payload = request->send_buffer();
  auto* const output = reply->mutable_receive_buffer();
  handler_.Allgather(payload.data(), payload.size(), output, request->sequence_number(),
                     request->rank());
  return grpc::Status::OK;
}
grpc::Status FederatedService::Allreduce(grpc::ServerContext* context,
AllreduceRequest const* request, AllreduceReply* reply) {
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),

View File

@@ -14,6 +14,9 @@ class FederatedService final : public Federated::Service {
public:
explicit FederatedService(int const world_size) : handler_{world_size} {}
/** \brief Handler for the Allgather RPC; reads \p request and fills in \p reply. */
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
AllreduceReply* reply) override;