Initial support for federated learning (#7831)

Federated learning plugin for xgboost:
* A gRPC server to aggregate MPI-style requests (allgather, allreduce, broadcast) from federated workers.
* A Rabit engine for the federated environment.
* Integration test to simulate federated learning.

Additional followups are needed to address GPU support, better security, and privacy, etc.
This commit is contained in:
Rong Ou
2022-05-05 06:49:22 -07:00
committed by GitHub
parent 46e0bce212
commit 14ef38b834
16 changed files with 1087 additions and 1 deletions

View File

@@ -0,0 +1,27 @@
# gRPC needs to be installed first. See README.md.
find_package(Protobuf REQUIRED)
find_package(gRPC REQUIRED)
find_package(Threads)
# Generated code from the protobuf definition.
add_library(federated_proto federated.proto)
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON)
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
protobuf_generate(TARGET federated_proto LANGUAGE cpp)
protobuf_generate(
TARGET federated_proto
LANGUAGE grpc
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}")
# Wrapper for the gRPC client.
add_library(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)
# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc)
target_link_libraries(objxgboost PRIVATE federated_client)
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)

View File

@@ -0,0 +1,35 @@
XGBoost Plugin for Federated Learning
=====================================
This folder contains the plugin for federated learning. Follow these steps to build and test it.
Install gRPC
------------
```shell
sudo apt-get install build-essential autoconf libtool pkg-config cmake ninja-build
git clone -b v1.45.2 https://github.com/grpc/grpc
cd grpc
git submodule update --init
cmake -S . -B build -GNinja -DABSL_PROPAGATE_CXX_STD=ON
cmake --build build --target install
```
Build the Plugin
----------------
```shell
# Under xgboost source tree.
mkdir build
cd build
cmake .. -GNinja -DPLUGIN_FEDERATED=ON
ninja
cd ../python-package
pip install -e . # or equivalently python setup.py develop
```
Test Federated XGBoost
----------------------
```shell
# Under xgboost source tree.
cd tests/distributed
./runtests-federated.sh
```

View File

@@ -0,0 +1,274 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <cstdio>
#include <fstream>
#include <sstream>
#include "federated_client.h"
#include "rabit/internal/engine.h"
#include "rabit/internal/utils.h"
namespace MPI { // NOLINT
// MPI data type to be compatible with existing MPI interface
class Datatype {
public:
size_t type_size;
explicit Datatype(size_t type_size) : type_size(type_size) {}
};
} // namespace MPI
namespace rabit {
namespace engine {
/*! \brief implementation of engine using federated learning */
class FederatedEngine : public IEngine {
public:
void Init(int argc, char *argv[]) {
// Parse environment variables first.
for (auto const &env_var : env_vars_) {
char const *value = getenv(env_var.c_str());
if (value != nullptr) {
SetParam(env_var, value);
}
}
// Command line argument overrides.
for (int i = 0; i < argc; ++i) {
std::string const key_value = argv[i];
auto const delimiter = key_value.find('=');
if (delimiter != std::string::npos) {
SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1));
}
}
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
server_address_.c_str(), world_size_, rank_);
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
client_key_, client_cert_));
}
void Finalize() { client_.reset(); }
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end,
size_t size_prev_slice) override {
throw std::logic_error("FederatedEngine:: Allgather is not supported");
}
std::string Allgather(void *sendbuf, size_t total_size) {
std::string const send_buffer(reinterpret_cast<char *>(sendbuf), total_size);
return client_->Allgather(send_buffer);
}
void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer,
PreprocFunction prepare_fun, void *prepare_arg) override {
throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead");
}
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op));
receive_buffer.copy(buffer, size);
}
int GetRingPrevRank() const override {
throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported");
}
void Broadcast(void *sendrecvbuf, size_t size, int root) override {
if (world_size_ == 1) return;
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Broadcast(send_buffer, root);
if (rank_ != root) {
receive_buffer.copy(buffer, size);
}
}
int LoadCheckPoint(Serializable *global_model, Serializable *local_model = nullptr) override {
return 0;
}
void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
version_number_ += 1;
}
void LazyCheckPoint(const Serializable *global_model) override { version_number_ += 1; }
int VersionNumber() const override { return version_number_; }
/*! \brief get rank of current node */
int GetRank() const override { return rank_; }
/*! \brief get total number of */
int GetWorldSize() const override { return world_size_; }
/*! \brief whether it is distributed */
bool IsDistributed() const override { return true; }
/*! \brief get the host name of current node */
std::string GetHost() const override { return "rank" + std::to_string(rank_); }
void TrackerPrint(const std::string &msg) override {
// simply print information into the tracker
if (GetRank() == 0) {
utils::Printf("%s", msg.c_str());
}
}
private:
/** @brief Transform mpi::DataType to xgboost::federated::DataType. */
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) {
switch (data_type) {
case mpi::kChar:
return xgboost::federated::CHAR;
case mpi::kUChar:
return xgboost::federated::UCHAR;
case mpi::kInt:
return xgboost::federated::INT;
case mpi::kUInt:
return xgboost::federated::UINT;
case mpi::kLong:
return xgboost::federated::LONG;
case mpi::kULong:
return xgboost::federated::ULONG;
case mpi::kFloat:
return xgboost::federated::FLOAT;
case mpi::kDouble:
return xgboost::federated::DOUBLE;
case mpi::kLongLong:
return xgboost::federated::LONGLONG;
case mpi::kULongLong:
return xgboost::federated::ULONGLONG;
}
utils::Error("unknown mpi::DataType");
return xgboost::federated::CHAR;
}
/** @brief Transform mpi::OpType to enum to MPI OP */
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) {
switch (op_type) {
case mpi::kMax:
return xgboost::federated::MAX;
case mpi::kMin:
return xgboost::federated::MIN;
case mpi::kSum:
return xgboost::federated::SUM;
case mpi::kBitwiseOR:
utils::Error("Bitwise OR is not supported");
return xgboost::federated::MAX;
}
utils::Error("unknown mpi::OpType");
return xgboost::federated::MAX;
}
void SetParam(std::string const &name, std::string const &val) {
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
server_address_ = val;
} else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) {
world_size_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) {
rank_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) {
server_cert_ = ReadFile(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) {
client_key_ = ReadFile(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) {
client_cert_ = ReadFile(val);
}
}
static std::string ReadFile(std::string const &path) {
auto stream = std::ifstream(path.data());
std::ostringstream out;
out << stream.rdbuf();
return out.str();
}
// clang-format off
std::vector<std::string> const env_vars_{
"FEDERATED_SERVER_ADDRESS",
"FEDERATED_WORLD_SIZE",
"FEDERATED_RANK",
"FEDERATED_SERVER_CERT",
"FEDERATED_CLIENT_KEY",
"FEDERATED_CLIENT_CERT" };
// clang-format on
std::string server_address_{"localhost:9091"};
int world_size_{1};
int rank_{0};
std::string server_cert_{};
std::string client_key_{};
std::string client_cert_{};
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
int version_number_{0};
};
// Singleton federated engine.
FederatedEngine engine; // NOLINT(cert-err58-cpp)
/*! \brief initialize the synchronization module */
bool Init(int argc, char *argv[]) {
try {
engine.Init(argc, argv);
return true;
} catch (std::exception const &e) {
fprintf(stderr, " failed in federated Init %s\n", e.what());
return false;
}
}
/*! \brief finalize synchronization module */
bool Finalize() {
try {
engine.Finalize();
return true;
} catch (const std::exception &e) {
fprintf(stderr, "failed in federated shutdown %s\n", e.what());
return false;
}
}
/*! \brief singleton method to get engine */
IEngine *GetEngine() { return &engine; }
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red,
mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (engine.GetWorldSize() == 1) return;
engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op);
}
ReduceHandle::ReduceHandle() = default;
ReduceHandle::~ReduceHandle() = default;
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast<int>(dtype.type_size); }
void ReduceHandle::Init(IEngine::ReduceFunction redfunc,
__attribute__((unused)) size_t type_nbytes) {
utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice");
redfunc_ = redfunc;
}
void ReduceHandle::Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun, void *prepare_arg) {
utils::Assert(redfunc_ != nullptr, "must initialize handle to call AllReduce");
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (engine.GetWorldSize() == 1) return;
// Gather all the buffers and call the reduce function locally.
auto const buffer_size = type_nbytes * count;
auto const gathered = engine.Allgather(sendrecvbuf, buffer_size);
auto const *data = gathered.data();
for (int i = 0; i < engine.GetWorldSize(); i++) {
if (i != engine.GetRank()) {
redfunc_(data + buffer_size * i, sendrecvbuf, static_cast<int>(count),
MPI::Datatype(type_nbytes));
}
}
}
} // namespace engine
} // namespace rabit

View File

@@ -0,0 +1,68 @@
/*!
* Copyright 2022 XGBoost contributors
*/
syntax = "proto3";
package xgboost.federated;
service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
}
enum DataType {
CHAR = 0;
UCHAR = 1;
INT = 2;
UINT = 3;
LONG = 4;
ULONG = 5;
FLOAT = 6;
DOUBLE = 7;
LONGLONG = 8;
ULONGLONG = 9;
}
enum ReduceOperation {
MAX = 0;
MIN = 1;
SUM = 2;
}
message AllgatherRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
}
message AllgatherReply {
bytes receive_buffer = 1;
}
message AllreduceRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
DataType data_type = 4;
ReduceOperation reduce_operation = 5;
}
message AllreduceReply {
bytes receive_buffer = 1;
}
message BroadcastRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
// The root rank to broadcast from.
int32 root = 4;
}
message BroadcastReply {
bytes receive_buffer = 1;
}

View File

@@ -0,0 +1,104 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <federated.pb.h>
#include <grpcpp/grpcpp.h>
#include <cstdio>
#include <cstdlib>
#include <string>
namespace xgboost {
namespace federated {
/**
* @brief A wrapper around the gRPC client.
*/
class FederatedClient {
public:
FederatedClient(std::string const &server_address, int rank, std::string const &server_cert,
std::string const &client_key, std::string const &client_cert)
: stub_{[&] {
grpc::SslCredentialsOptions options;
options.pem_root_certs = server_cert;
options.pem_private_key = client_key;
options.pem_cert_chain = client_cert;
return Federated::NewStub(
grpc::CreateChannel(server_address, grpc::SslCredentials(options)));
}()},
rank_{rank} {}
/** @brief Insecure client for testing only. */
FederatedClient(std::string const &server_address, int rank)
: stub_{Federated::NewStub(
grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))},
rank_{rank} {}
std::string Allgather(std::string const &send_buffer) {
AllgatherRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
AllgatherReply reply;
grpc::ClientContext context;
grpc::Status status = stub_->Allgather(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Allgather RPC failed");
}
}
std::string Allreduce(std::string const &send_buffer, DataType data_type,
ReduceOperation reduce_operation) {
AllreduceRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
request.set_data_type(data_type);
request.set_reduce_operation(reduce_operation);
AllreduceReply reply;
grpc::ClientContext context;
grpc::Status status = stub_->Allreduce(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Allreduce RPC failed");
}
}
std::string Broadcast(std::string const &send_buffer, int root) {
BroadcastRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
request.set_root(root);
BroadcastReply reply;
grpc::ClientContext context;
grpc::Status status = stub_->Broadcast(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Broadcast RPC failed");
}
}
private:
std::unique_ptr<Federated::Stub> const stub_;
int const rank_;
uint64_t sequence_number_{};
};
} // namespace federated
} // namespace xgboost

View File

@@ -0,0 +1,234 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "federated_server.h"
#include <grpcpp/grpcpp.h>
#include <grpcpp/server_builder.h>
#include <xgboost/logging.h>
#include <fstream>
#include <sstream>
namespace xgboost {
namespace federated {
class AllgatherFunctor {
public:
std::string const name{"Allgather"};
explicit AllgatherFunctor(int const world_size) : world_size_{world_size} {}
void operator()(AllgatherRequest const* request, std::string& buffer) const {
auto const rank = request->rank();
auto const& send_buffer = request->send_buffer();
auto const send_size = send_buffer.size();
// Resize the buffer if this is the first request.
if (buffer.size() != send_size * world_size_) {
buffer.resize(send_size * world_size_);
}
// Splice the send_buffer into the common buffer.
buffer.replace(rank * send_size, send_size, send_buffer);
}
private:
int const world_size_;
};
class AllreduceFunctor {
public:
std::string const name{"Allreduce"};
void operator()(AllreduceRequest const* request, std::string& buffer) const {
if (buffer.empty()) {
// Copy the send_buffer if this is the first request.
buffer = request->send_buffer();
} else {
// Apply the reduce_operation to the send_buffer and the common buffer.
Accumulate(buffer, request->send_buffer(), request->data_type(), request->reduce_operation());
}
}
private:
template <class T>
void Accumulate(T* buffer, T const* input, std::size_t n,
ReduceOperation reduce_operation) const {
switch (reduce_operation) {
case ReduceOperation::MAX:
std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::max(a, b); });
break;
case ReduceOperation::MIN:
std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::min(a, b); });
break;
case ReduceOperation::SUM:
std::transform(buffer, buffer + n, input, buffer, std::plus<T>());
break;
default:
throw std::invalid_argument("Invalid reduce operation");
}
}
void Accumulate(std::string& buffer, std::string const& input, DataType data_type,
ReduceOperation reduce_operation) const {
switch (data_type) {
case DataType::CHAR:
Accumulate(&buffer[0], reinterpret_cast<char const*>(input.data()), buffer.size(),
reduce_operation);
break;
case DataType::UCHAR:
Accumulate(reinterpret_cast<unsigned char*>(&buffer[0]),
reinterpret_cast<unsigned char const*>(input.data()), buffer.size(),
reduce_operation);
break;
case DataType::INT:
Accumulate(reinterpret_cast<int*>(&buffer[0]), reinterpret_cast<int const*>(input.data()),
buffer.size() / sizeof(int), reduce_operation);
break;
case DataType::UINT:
Accumulate(reinterpret_cast<unsigned int*>(&buffer[0]),
reinterpret_cast<unsigned int const*>(input.data()),
buffer.size() / sizeof(unsigned int), reduce_operation);
break;
case DataType::LONG:
Accumulate(reinterpret_cast<long*>(&buffer[0]), reinterpret_cast<long const*>(input.data()),
buffer.size() / sizeof(long), reduce_operation);
break;
case DataType::ULONG:
Accumulate(reinterpret_cast<unsigned long*>(&buffer[0]),
reinterpret_cast<unsigned long const*>(input.data()),
buffer.size() / sizeof(unsigned long), reduce_operation);
break;
case DataType::FLOAT:
Accumulate(reinterpret_cast<float*>(&buffer[0]),
reinterpret_cast<float const*>(input.data()), buffer.size() / sizeof(float),
reduce_operation);
break;
case DataType::DOUBLE:
Accumulate(reinterpret_cast<double*>(&buffer[0]),
reinterpret_cast<double const*>(input.data()), buffer.size() / sizeof(double),
reduce_operation);
break;
case DataType::LONGLONG:
Accumulate(reinterpret_cast<long long*>(&buffer[0]),
reinterpret_cast<long long const*>(input.data()),
buffer.size() / sizeof(long long), reduce_operation);
break;
case DataType::ULONGLONG:
Accumulate(reinterpret_cast<unsigned long long*>(&buffer[0]),
reinterpret_cast<unsigned long long const*>(input.data()),
buffer.size() / sizeof(unsigned long long), reduce_operation);
break;
default:
throw std::invalid_argument("Invalid data type");
}
}
};
class BroadcastFunctor {
public:
std::string const name{"Broadcast"};
void operator()(BroadcastRequest const* request, std::string& buffer) const {
if (request->rank() == request->root()) {
// Copy the send_buffer if this is the root.
buffer = request->send_buffer();
}
}
};
grpc::Status FederatedService::Allgather(grpc::ServerContext* context,
AllgatherRequest const* request, AllgatherReply* reply) {
return Handle(request, reply, AllgatherFunctor{world_size_});
}
grpc::Status FederatedService::Allreduce(grpc::ServerContext* context,
AllreduceRequest const* request, AllreduceReply* reply) {
return Handle(request, reply, AllreduceFunctor{});
}
grpc::Status FederatedService::Broadcast(grpc::ServerContext* context,
BroadcastRequest const* request, BroadcastReply* reply) {
return Handle(request, reply, BroadcastFunctor{});
}
template <class Request, class Reply, class RequestFunctor>
grpc::Status FederatedService::Handle(Request const* request, Reply* reply,
RequestFunctor const& functor) {
// Pass through if there is only 1 client.
if (world_size_ == 1) {
reply->set_receive_buffer(request->send_buffer());
return grpc::Status::OK;
}
std::unique_lock<std::mutex> lock(mutex_);
auto const sequence_number = request->sequence_number();
auto const rank = request->rank();
LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number";
cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; });
LOG(INFO) << functor.name << " rank " << rank << ": handling request";
functor(request, buffer_);
received_++;
if (received_ == world_size_) {
LOG(INFO) << functor.name << " rank " << rank << ": all requests received";
reply->set_receive_buffer(buffer_);
sent_++;
lock.unlock();
cv_.notify_all();
return grpc::Status::OK;
}
LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients";
cv_.wait(lock, [this] { return received_ == world_size_; });
LOG(INFO) << functor.name << " rank " << rank << ": sending reply";
reply->set_receive_buffer(buffer_);
sent_++;
if (sent_ == world_size_) {
LOG(INFO) << functor.name << " rank " << rank << ": all replies sent";
sent_ = 0;
received_ = 0;
buffer_.clear();
sequence_number_++;
lock.unlock();
cv_.notify_all();
}
return grpc::Status::OK;
}
std::string ReadFile(char const* path) {
auto stream = std::ifstream(path);
std::ostringstream out;
out << stream.rdbuf();
return out.str();
}
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};
grpc::ServerBuilder builder;
auto options =
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
options.pem_root_certs = ReadFile(client_cert_file);
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
key.private_key = ReadFile(server_key_file);
key.cert_chain = ReadFile(server_cert_file);
options.pem_key_cert_pairs.push_back(key);
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
builder.RegisterService(&service);
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
<< world_size;
server->Wait();
}
} // namespace federated
} // namespace xgboost

View File

@@ -0,0 +1,44 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <condition_variable>
#include <mutex>
namespace xgboost {
namespace federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(int const world_size) : world_size_{world_size} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
AllreduceReply* reply) override;
grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
BroadcastReply* reply) override;
private:
template <class Request, class Reply, class RequestFunctor>
grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor);
int const world_size_;
int received_{};
int sent_{};
std::string buffer_{};
uint64_t sequence_number_{};
mutable std::mutex mutex_;
mutable std::condition_variable cv_;
};
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file);
} // namespace federated
} // namespace xgboost