Improve allgather functions (#9649)
This commit is contained in:
@@ -7,6 +7,7 @@ package xgboost.federated;
|
||||
|
||||
service Federated {
|
||||
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
|
||||
rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
|
||||
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
|
||||
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
|
||||
}
|
||||
@@ -42,6 +43,17 @@ message AllgatherReply {
|
||||
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
message AllgatherVRequest {
|
||||
// An incrementing counter that is unique to each round to operations.
|
||||
uint64 sequence_number = 1;
|
||||
int32 rank = 2;
|
||||
bytes send_buffer = 3;
|
||||
}
|
||||
|
||||
message AllgatherVReply {
|
||||
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
message AllreduceRequest {
|
||||
// An incrementing counter that is unique to each round to operations.
|
||||
uint64 sequence_number = 1;
|
||||
|
||||
@@ -44,11 +44,11 @@ class FederatedClient {
|
||||
}()},
|
||||
rank_{rank} {}
|
||||
|
||||
std::string Allgather(std::string const &send_buffer) {
|
||||
std::string Allgather(std::string_view send_buffer) {
|
||||
AllgatherRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer);
|
||||
request.set_send_buffer(send_buffer.data(), send_buffer.size());
|
||||
|
||||
AllgatherReply reply;
|
||||
grpc::ClientContext context;
|
||||
@@ -63,6 +63,25 @@ class FederatedClient {
|
||||
}
|
||||
}
|
||||
|
||||
std::string AllgatherV(std::string_view send_buffer) {
|
||||
AllgatherVRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer.data(), send_buffer.size());
|
||||
|
||||
AllgatherVReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub_->AllgatherV(&context, request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
return reply.receive_buffer();
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||
throw std::runtime_error("AllgatherV RPC failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::string Allreduce(std::string const &send_buffer, DataType data_type,
|
||||
ReduceOperation reduce_operation) {
|
||||
AllreduceRequest request;
|
||||
|
||||
@@ -125,14 +125,19 @@ class FederatedCommunicator : public Communicator {
|
||||
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||
|
||||
/**
|
||||
* \brief Perform in-place allgather.
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param size Number of bytes to be gathered.
|
||||
* \brief Perform allgather.
|
||||
* \param input Buffer for sending data.
|
||||
*/
|
||||
void AllGather(void *send_receive_buffer, std::size_t size) override {
|
||||
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
|
||||
auto const received = client_->Allgather(send_buffer);
|
||||
received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
|
||||
std::string AllGather(std::string_view input) override {
|
||||
return client_->Allgather(input);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Perform variable-length allgather.
|
||||
* \param input Buffer for sending data.
|
||||
*/
|
||||
std::string AllGatherV(std::string_view input) override {
|
||||
return client_->AllgatherV(input);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -19,6 +19,13 @@ grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request,
|
||||
AllgatherVReply* reply) {
|
||||
handler_.AllgatherV(request->send_buffer().data(), request->send_buffer().size(),
|
||||
reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request,
|
||||
AllreduceReply* reply) {
|
||||
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
|
||||
@@ -36,8 +43,8 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
||||
char const* client_cert_file) {
|
||||
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{world_size};
|
||||
|
||||
@@ -59,7 +66,7 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
|
||||
server->Wait();
|
||||
}
|
||||
|
||||
void RunInsecureServer(int port, int world_size) {
|
||||
void RunInsecureServer(int port, std::size_t world_size) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{world_size};
|
||||
|
||||
|
||||
@@ -12,11 +12,14 @@ namespace federated {
|
||||
|
||||
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
explicit FederatedService(int const world_size) : handler_{world_size} {}
|
||||
explicit FederatedService(std::size_t const world_size) : handler_{world_size} {}
|
||||
|
||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
|
||||
grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request,
|
||||
AllgatherVReply* reply) override;
|
||||
|
||||
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
|
||||
AllreduceReply* reply) override;
|
||||
|
||||
@@ -27,10 +30,10 @@ class FederatedService final : public Federated::Service {
|
||||
xgboost::collective::InMemoryHandler handler_;
|
||||
};
|
||||
|
||||
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
||||
char const* client_cert_file);
|
||||
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file);
|
||||
|
||||
void RunInsecureServer(int port, int world_size);
|
||||
void RunInsecureServer(int port, std::size_t world_size);
|
||||
|
||||
} // namespace federated
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user