xgboost/plugin/federated/federated_server.h
Rong Ou 14ef38b834
Initial support for federated learning (#7831)
Federated learning plugin for xgboost:
* A gRPC server to aggregate MPI-style requests (allgather, allreduce, broadcast) from federated workers.
* A Rabit engine for the federated environment.
* Integration test to simulate federated learning.

Additional followups are needed to address GPU support, better security, and privacy, etc.
2022-05-05 21:49:22 +08:00

45 lines
1.3 KiB
C++

/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>

#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <string>
namespace xgboost {
namespace federated {

/*!
 * \brief gRPC service that aggregates MPI-style collective requests (allgather, allreduce,
 *        broadcast) sent by federated workers.
 *
 * The three RPC overrides are declared here and implemented out of line. The private state
 * (receive/send counters, a shared buffer, a sequence number and a mutex/condition-variable
 * pair) suggests each call waits until all `world_size` workers have joined the current
 * operation. NOTE(review): confirm against the Handle() implementation.
 */
class FederatedService final : public Federated::Service {
 public:
  /*!
   * \brief Construct the service.
   * \param world_size MPI-style world size, i.e. presumably the number of participating
   *        federated workers. Stored as-is; no validation is done here.
   */
  explicit FederatedService(int const world_size) : world_size_{world_size} {}

  /*! \brief Serve one worker's allgather request, filling `reply`. */
  grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
                         AllgatherReply* reply) override;

  /*! \brief Serve one worker's allreduce request, filling `reply`. */
  grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
                         AllreduceReply* reply) override;

  /*! \brief Serve one worker's broadcast request, filling `reply`. */
  grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
                         BroadcastReply* reply) override;

 private:
  /*!
   * \brief Common driver shared by the collective RPCs.
   *
   * Presumably `functor` performs the operation-specific step on the shared buffer.
   * NOTE(review): exact semantics live in the .cc file; confirm there.
   */
  template <class Request, class Reply, class RequestFunctor>
  grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor);

  int const world_size_;             // Fixed at construction.
  int received_{};                   // Presumably: requests received for the current operation.
  int sent_{};                       // Presumably: replies sent for the current operation.
  std::string buffer_{};             // Presumably: the payload accumulated across workers.
  std::uint64_t sequence_number_{};  // Presumably: identifies the current operation.

  // Not `mutable`: no const member function of this class locks or waits on these.
  std::mutex mutex_;            // Presumably guards the state above.
  std::condition_variable cv_;  // Used with mutex_, presumably to wait for all workers.
};

/*!
 * \brief Run the federated gRPC server.
 * \param port Port the server presumably listens on.
 * \param world_size Presumably forwarded to FederatedService as its world size.
 * \param server_key_file Presumably the path to the server's private key (TLS); confirm.
 * \param server_cert_file Presumably the path to the server's certificate (TLS); confirm.
 * \param client_cert_file Presumably the path to the certificate used to verify clients; confirm.
 *
 * NOTE(review): whether this call blocks until shutdown is defined in the implementation file.
 */
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
               char const* client_cert_file);

}  // namespace federated
}  // namespace xgboost