// Federated learning plugin for XGBoost:
//   * A gRPC server that aggregates MPI-style requests (allgather, allreduce, broadcast)
//     from federated workers.
//   * A Rabit engine for the federated environment.
//   * An integration test that simulates federated learning.
// Follow-up work is needed to address GPU support, stronger security, privacy, etc.
/*!
|
|
* Copyright 2022 XGBoost contributors
|
|
*/
|
|
#pragma once
|
|
|
|
#include <federated.grpc.pb.h>

#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <string>
|
|
|
|
namespace xgboost {
|
|
namespace federated {
|
|
|
|
class FederatedService final : public Federated::Service {
|
|
public:
|
|
explicit FederatedService(int const world_size) : world_size_{world_size} {}
|
|
|
|
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
|
AllgatherReply* reply) override;
|
|
|
|
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
|
|
AllreduceReply* reply) override;
|
|
|
|
grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
|
|
BroadcastReply* reply) override;
|
|
|
|
private:
|
|
template <class Request, class Reply, class RequestFunctor>
|
|
grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor);
|
|
|
|
int const world_size_;
|
|
int received_{};
|
|
int sent_{};
|
|
std::string buffer_{};
|
|
uint64_t sequence_number_{};
|
|
mutable std::mutex mutex_;
|
|
mutable std::condition_variable cv_;
|
|
};
|
|
|
|
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
|
char const* client_cert_file);
|
|
|
|
} // namespace federated
|
|
} // namespace xgboost
|