Initial support for federated learning (#7831)

Federated learning plugin for xgboost:
* A gRPC server to aggregate MPI-style requests (allgather, allreduce, broadcast) from federated workers.
* A Rabit engine for the federated environment.
* Integration test to simulate federated learning.

Additional followups are needed to address GPU support, better security, and privacy, etc.
This commit is contained in:
Rong Ou 2022-05-05 06:49:22 -07:00 committed by GitHub
parent 46e0bce212
commit 14ef38b834
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1087 additions and 1 deletions

View File

@ -66,6 +66,7 @@ address, leak, undefined and thread.")
## Plugins
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF)
option(PLUGIN_FEDERATED "Build with Federated Learning" OFF)
## TODO: 1. Add check if DPC++ compiler is used for building
option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF)
option(ADD_PKGCONFIG "Add xgboost.pc into system." ON)

View File

@ -40,3 +40,8 @@ if (PLUGIN_UPDATER_ONEAPI)
# Add all objects of oneapi_plugin to objxgboost
target_sources(objxgboost INTERFACE $<TARGET_OBJECTS:oneapi_plugin>)
endif (PLUGIN_UPDATER_ONEAPI)
# Add the Federate Learning plugin if enabled.
if (PLUGIN_FEDERATED)
add_subdirectory(federated)
endif (PLUGIN_FEDERATED)

View File

@ -0,0 +1,27 @@
# gRPC needs to be installed first. See README.md.
find_package(Protobuf REQUIRED)
find_package(gRPC REQUIRED)
find_package(Threads)
# Generated code from the protobuf definition.
add_library(federated_proto federated.proto)
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON)
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
protobuf_generate(TARGET federated_proto LANGUAGE cpp)
protobuf_generate(
TARGET federated_proto
LANGUAGE grpc
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}")
# Wrapper for the gRPC client.
add_library(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)
# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc)
target_link_libraries(objxgboost PRIVATE federated_client)
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)

View File

@ -0,0 +1,35 @@
XGBoost Plugin for Federated Learning
=====================================
This folder contains the plugin for federated learning. Follow these steps to build and test it.
Install gRPC
------------
```shell
sudo apt-get install build-essential autoconf libtool pkg-config cmake ninja-build
git clone -b v1.45.2 https://github.com/grpc/grpc
cd grpc
git submodule update --init
cmake -S . -B build -GNinja -DABSL_PROPAGATE_CXX_STD=ON
cmake --build build --target install
```
Build the Plugin
----------------
```shell
# Under xgboost source tree.
mkdir build
cd build
cmake .. -GNinja -DPLUGIN_FEDERATED=ON
ninja
cd ../python-package
pip install -e . # or equivalently python setup.py develop
```
Test Federated XGBoost
----------------------
```shell
# Under xgboost source tree.
cd tests/distributed
./runtests-federated.sh
```

View File

@ -0,0 +1,274 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <cstdio>
#include <fstream>
#include <sstream>
#include "federated_client.h"
#include "rabit/internal/engine.h"
#include "rabit/internal/utils.h"
namespace MPI { // NOLINT
// MPI data type to be compatible with existing MPI interface
class Datatype {
public:
size_t type_size;
explicit Datatype(size_t type_size) : type_size(type_size) {}
};
} // namespace MPI
namespace rabit {
namespace engine {
/*! \brief implementation of engine using federated learning */
class FederatedEngine : public IEngine {
public:
void Init(int argc, char *argv[]) {
// Parse environment variables first.
for (auto const &env_var : env_vars_) {
char const *value = getenv(env_var.c_str());
if (value != nullptr) {
SetParam(env_var, value);
}
}
// Command line argument overrides.
for (int i = 0; i < argc; ++i) {
std::string const key_value = argv[i];
auto const delimiter = key_value.find('=');
if (delimiter != std::string::npos) {
SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1));
}
}
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
server_address_.c_str(), world_size_, rank_);
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
client_key_, client_cert_));
}
void Finalize() { client_.reset(); }
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end,
size_t size_prev_slice) override {
throw std::logic_error("FederatedEngine:: Allgather is not supported");
}
std::string Allgather(void *sendbuf, size_t total_size) {
std::string const send_buffer(reinterpret_cast<char *>(sendbuf), total_size);
return client_->Allgather(send_buffer);
}
void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer,
PreprocFunction prepare_fun, void *prepare_arg) override {
throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead");
}
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op));
receive_buffer.copy(buffer, size);
}
int GetRingPrevRank() const override {
throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported");
}
void Broadcast(void *sendrecvbuf, size_t size, int root) override {
if (world_size_ == 1) return;
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Broadcast(send_buffer, root);
if (rank_ != root) {
receive_buffer.copy(buffer, size);
}
}
int LoadCheckPoint(Serializable *global_model, Serializable *local_model = nullptr) override {
return 0;
}
void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
version_number_ += 1;
}
void LazyCheckPoint(const Serializable *global_model) override { version_number_ += 1; }
int VersionNumber() const override { return version_number_; }
/*! \brief get rank of current node */
int GetRank() const override { return rank_; }
/*! \brief get total number of */
int GetWorldSize() const override { return world_size_; }
/*! \brief whether it is distributed */
bool IsDistributed() const override { return true; }
/*! \brief get the host name of current node */
std::string GetHost() const override { return "rank" + std::to_string(rank_); }
void TrackerPrint(const std::string &msg) override {
// simply print information into the tracker
if (GetRank() == 0) {
utils::Printf("%s", msg.c_str());
}
}
private:
/** @brief Transform mpi::DataType to xgboost::federated::DataType. */
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) {
switch (data_type) {
case mpi::kChar:
return xgboost::federated::CHAR;
case mpi::kUChar:
return xgboost::federated::UCHAR;
case mpi::kInt:
return xgboost::federated::INT;
case mpi::kUInt:
return xgboost::federated::UINT;
case mpi::kLong:
return xgboost::federated::LONG;
case mpi::kULong:
return xgboost::federated::ULONG;
case mpi::kFloat:
return xgboost::federated::FLOAT;
case mpi::kDouble:
return xgboost::federated::DOUBLE;
case mpi::kLongLong:
return xgboost::federated::LONGLONG;
case mpi::kULongLong:
return xgboost::federated::ULONGLONG;
}
utils::Error("unknown mpi::DataType");
return xgboost::federated::CHAR;
}
/** @brief Transform mpi::OpType to enum to MPI OP */
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) {
switch (op_type) {
case mpi::kMax:
return xgboost::federated::MAX;
case mpi::kMin:
return xgboost::federated::MIN;
case mpi::kSum:
return xgboost::federated::SUM;
case mpi::kBitwiseOR:
utils::Error("Bitwise OR is not supported");
return xgboost::federated::MAX;
}
utils::Error("unknown mpi::OpType");
return xgboost::federated::MAX;
}
void SetParam(std::string const &name, std::string const &val) {
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
server_address_ = val;
} else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) {
world_size_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) {
rank_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) {
server_cert_ = ReadFile(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) {
client_key_ = ReadFile(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) {
client_cert_ = ReadFile(val);
}
}
static std::string ReadFile(std::string const &path) {
auto stream = std::ifstream(path.data());
std::ostringstream out;
out << stream.rdbuf();
return out.str();
}
// clang-format off
std::vector<std::string> const env_vars_{
"FEDERATED_SERVER_ADDRESS",
"FEDERATED_WORLD_SIZE",
"FEDERATED_RANK",
"FEDERATED_SERVER_CERT",
"FEDERATED_CLIENT_KEY",
"FEDERATED_CLIENT_CERT" };
// clang-format on
std::string server_address_{"localhost:9091"};
int world_size_{1};
int rank_{0};
std::string server_cert_{};
std::string client_key_{};
std::string client_cert_{};
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
int version_number_{0};
};
// Singleton federated engine.
FederatedEngine engine; // NOLINT(cert-err58-cpp)
/*! \brief initialize the synchronization module */
bool Init(int argc, char *argv[]) {
try {
engine.Init(argc, argv);
return true;
} catch (std::exception const &e) {
fprintf(stderr, " failed in federated Init %s\n", e.what());
return false;
}
}
/*! \brief finalize synchronization module */
bool Finalize() {
try {
engine.Finalize();
return true;
} catch (const std::exception &e) {
fprintf(stderr, "failed in federated shutdown %s\n", e.what());
return false;
}
}
/*! \brief singleton method to get engine */
IEngine *GetEngine() { return &engine; }
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red,
mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (engine.GetWorldSize() == 1) return;
engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op);
}
ReduceHandle::ReduceHandle() = default;
ReduceHandle::~ReduceHandle() = default;
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast<int>(dtype.type_size); }
void ReduceHandle::Init(IEngine::ReduceFunction redfunc,
__attribute__((unused)) size_t type_nbytes) {
utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice");
redfunc_ = redfunc;
}
void ReduceHandle::Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun, void *prepare_arg) {
utils::Assert(redfunc_ != nullptr, "must initialize handle to call AllReduce");
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (engine.GetWorldSize() == 1) return;
// Gather all the buffers and call the reduce function locally.
auto const buffer_size = type_nbytes * count;
auto const gathered = engine.Allgather(sendrecvbuf, buffer_size);
auto const *data = gathered.data();
for (int i = 0; i < engine.GetWorldSize(); i++) {
if (i != engine.GetRank()) {
redfunc_(data + buffer_size * i, sendrecvbuf, static_cast<int>(count),
MPI::Datatype(type_nbytes));
}
}
}
} // namespace engine
} // namespace rabit

View File

@ -0,0 +1,68 @@
/*!
* Copyright 2022 XGBoost contributors
*/
syntax = "proto3";
package xgboost.federated;
service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
}
enum DataType {
CHAR = 0;
UCHAR = 1;
INT = 2;
UINT = 3;
LONG = 4;
ULONG = 5;
FLOAT = 6;
DOUBLE = 7;
LONGLONG = 8;
ULONGLONG = 9;
}
enum ReduceOperation {
MAX = 0;
MIN = 1;
SUM = 2;
}
message AllgatherRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
}
message AllgatherReply {
bytes receive_buffer = 1;
}
message AllreduceRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
DataType data_type = 4;
ReduceOperation reduce_operation = 5;
}
message AllreduceReply {
bytes receive_buffer = 1;
}
message BroadcastRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
// The root rank to broadcast from.
int32 root = 4;
}
message BroadcastReply {
bytes receive_buffer = 1;
}

View File

@ -0,0 +1,104 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <federated.pb.h>
#include <grpcpp/grpcpp.h>
#include <cstdio>
#include <cstdlib>
#include <string>
namespace xgboost {
namespace federated {
/**
* @brief A wrapper around the gRPC client.
*/
class FederatedClient {
public:
FederatedClient(std::string const &server_address, int rank, std::string const &server_cert,
std::string const &client_key, std::string const &client_cert)
: stub_{[&] {
grpc::SslCredentialsOptions options;
options.pem_root_certs = server_cert;
options.pem_private_key = client_key;
options.pem_cert_chain = client_cert;
return Federated::NewStub(
grpc::CreateChannel(server_address, grpc::SslCredentials(options)));
}()},
rank_{rank} {}
/** @brief Insecure client for testing only. */
FederatedClient(std::string const &server_address, int rank)
: stub_{Federated::NewStub(
grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))},
rank_{rank} {}
std::string Allgather(std::string const &send_buffer) {
AllgatherRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
AllgatherReply reply;
grpc::ClientContext context;
grpc::Status status = stub_->Allgather(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Allgather RPC failed");
}
}
std::string Allreduce(std::string const &send_buffer, DataType data_type,
ReduceOperation reduce_operation) {
AllreduceRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
request.set_data_type(data_type);
request.set_reduce_operation(reduce_operation);
AllreduceReply reply;
grpc::ClientContext context;
grpc::Status status = stub_->Allreduce(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Allreduce RPC failed");
}
}
std::string Broadcast(std::string const &send_buffer, int root) {
BroadcastRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
request.set_root(root);
BroadcastReply reply;
grpc::ClientContext context;
grpc::Status status = stub_->Broadcast(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Broadcast RPC failed");
}
}
private:
std::unique_ptr<Federated::Stub> const stub_;
int const rank_;
uint64_t sequence_number_{};
};
} // namespace federated
} // namespace xgboost

View File

@ -0,0 +1,234 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "federated_server.h"
#include <grpcpp/grpcpp.h>
#include <grpcpp/server_builder.h>
#include <xgboost/logging.h>
#include <fstream>
#include <sstream>
namespace xgboost {
namespace federated {
class AllgatherFunctor {
public:
std::string const name{"Allgather"};
explicit AllgatherFunctor(int const world_size) : world_size_{world_size} {}
void operator()(AllgatherRequest const* request, std::string& buffer) const {
auto const rank = request->rank();
auto const& send_buffer = request->send_buffer();
auto const send_size = send_buffer.size();
// Resize the buffer if this is the first request.
if (buffer.size() != send_size * world_size_) {
buffer.resize(send_size * world_size_);
}
// Splice the send_buffer into the common buffer.
buffer.replace(rank * send_size, send_size, send_buffer);
}
private:
int const world_size_;
};
class AllreduceFunctor {
public:
std::string const name{"Allreduce"};
void operator()(AllreduceRequest const* request, std::string& buffer) const {
if (buffer.empty()) {
// Copy the send_buffer if this is the first request.
buffer = request->send_buffer();
} else {
// Apply the reduce_operation to the send_buffer and the common buffer.
Accumulate(buffer, request->send_buffer(), request->data_type(), request->reduce_operation());
}
}
private:
template <class T>
void Accumulate(T* buffer, T const* input, std::size_t n,
ReduceOperation reduce_operation) const {
switch (reduce_operation) {
case ReduceOperation::MAX:
std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::max(a, b); });
break;
case ReduceOperation::MIN:
std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::min(a, b); });
break;
case ReduceOperation::SUM:
std::transform(buffer, buffer + n, input, buffer, std::plus<T>());
break;
default:
throw std::invalid_argument("Invalid reduce operation");
}
}
void Accumulate(std::string& buffer, std::string const& input, DataType data_type,
ReduceOperation reduce_operation) const {
switch (data_type) {
case DataType::CHAR:
Accumulate(&buffer[0], reinterpret_cast<char const*>(input.data()), buffer.size(),
reduce_operation);
break;
case DataType::UCHAR:
Accumulate(reinterpret_cast<unsigned char*>(&buffer[0]),
reinterpret_cast<unsigned char const*>(input.data()), buffer.size(),
reduce_operation);
break;
case DataType::INT:
Accumulate(reinterpret_cast<int*>(&buffer[0]), reinterpret_cast<int const*>(input.data()),
buffer.size() / sizeof(int), reduce_operation);
break;
case DataType::UINT:
Accumulate(reinterpret_cast<unsigned int*>(&buffer[0]),
reinterpret_cast<unsigned int const*>(input.data()),
buffer.size() / sizeof(unsigned int), reduce_operation);
break;
case DataType::LONG:
Accumulate(reinterpret_cast<long*>(&buffer[0]), reinterpret_cast<long const*>(input.data()),
buffer.size() / sizeof(long), reduce_operation);
break;
case DataType::ULONG:
Accumulate(reinterpret_cast<unsigned long*>(&buffer[0]),
reinterpret_cast<unsigned long const*>(input.data()),
buffer.size() / sizeof(unsigned long), reduce_operation);
break;
case DataType::FLOAT:
Accumulate(reinterpret_cast<float*>(&buffer[0]),
reinterpret_cast<float const*>(input.data()), buffer.size() / sizeof(float),
reduce_operation);
break;
case DataType::DOUBLE:
Accumulate(reinterpret_cast<double*>(&buffer[0]),
reinterpret_cast<double const*>(input.data()), buffer.size() / sizeof(double),
reduce_operation);
break;
case DataType::LONGLONG:
Accumulate(reinterpret_cast<long long*>(&buffer[0]),
reinterpret_cast<long long const*>(input.data()),
buffer.size() / sizeof(long long), reduce_operation);
break;
case DataType::ULONGLONG:
Accumulate(reinterpret_cast<unsigned long long*>(&buffer[0]),
reinterpret_cast<unsigned long long const*>(input.data()),
buffer.size() / sizeof(unsigned long long), reduce_operation);
break;
default:
throw std::invalid_argument("Invalid data type");
}
}
};
class BroadcastFunctor {
public:
std::string const name{"Broadcast"};
void operator()(BroadcastRequest const* request, std::string& buffer) const {
if (request->rank() == request->root()) {
// Copy the send_buffer if this is the root.
buffer = request->send_buffer();
}
}
};
grpc::Status FederatedService::Allgather(grpc::ServerContext* context,
AllgatherRequest const* request, AllgatherReply* reply) {
return Handle(request, reply, AllgatherFunctor{world_size_});
}
grpc::Status FederatedService::Allreduce(grpc::ServerContext* context,
AllreduceRequest const* request, AllreduceReply* reply) {
return Handle(request, reply, AllreduceFunctor{});
}
grpc::Status FederatedService::Broadcast(grpc::ServerContext* context,
BroadcastRequest const* request, BroadcastReply* reply) {
return Handle(request, reply, BroadcastFunctor{});
}
template <class Request, class Reply, class RequestFunctor>
grpc::Status FederatedService::Handle(Request const* request, Reply* reply,
RequestFunctor const& functor) {
// Pass through if there is only 1 client.
if (world_size_ == 1) {
reply->set_receive_buffer(request->send_buffer());
return grpc::Status::OK;
}
std::unique_lock<std::mutex> lock(mutex_);
auto const sequence_number = request->sequence_number();
auto const rank = request->rank();
LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number";
cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; });
LOG(INFO) << functor.name << " rank " << rank << ": handling request";
functor(request, buffer_);
received_++;
if (received_ == world_size_) {
LOG(INFO) << functor.name << " rank " << rank << ": all requests received";
reply->set_receive_buffer(buffer_);
sent_++;
lock.unlock();
cv_.notify_all();
return grpc::Status::OK;
}
LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients";
cv_.wait(lock, [this] { return received_ == world_size_; });
LOG(INFO) << functor.name << " rank " << rank << ": sending reply";
reply->set_receive_buffer(buffer_);
sent_++;
if (sent_ == world_size_) {
LOG(INFO) << functor.name << " rank " << rank << ": all replies sent";
sent_ = 0;
received_ = 0;
buffer_.clear();
sequence_number_++;
lock.unlock();
cv_.notify_all();
}
return grpc::Status::OK;
}
std::string ReadFile(char const* path) {
auto stream = std::ifstream(path);
std::ostringstream out;
out << stream.rdbuf();
return out.str();
}
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};
grpc::ServerBuilder builder;
auto options =
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
options.pem_root_certs = ReadFile(client_cert_file);
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
key.private_key = ReadFile(server_key_file);
key.cert_chain = ReadFile(server_cert_file);
options.pem_key_cert_pairs.push_back(key);
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
builder.RegisterService(&service);
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
<< world_size;
server->Wait();
}
} // namespace federated
} // namespace xgboost

View File

@ -0,0 +1,44 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <condition_variable>
#include <mutex>
namespace xgboost {
namespace federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(int const world_size) : world_size_{world_size} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
AllreduceReply* reply) override;
grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
BroadcastReply* reply) override;
private:
template <class Request, class Reply, class RequestFunctor>
grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor);
int const world_size_;
int received_{};
int sent_{};
std::string buffer_{};
uint64_t sequence_number_{};
mutable std::mutex mutex_;
mutable std::condition_variable cv_;
};
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file);
} // namespace federated
} // namespace xgboost

View File

@ -0,0 +1,36 @@
"""XGBoost Federated Learning related API."""
from .core import _LIB, _check_call, c_str, build_info, XGBoostError
def run_federated_server(port: int,
world_size: int,
server_key_path: str,
server_cert_path: str,
client_cert_path: str) -> None:
"""Run the Federated Learning server.
Parameters
----------
port : int
The port to listen on.
world_size: int
The number of federated workers.
server_key_path: str
Path to the server private key file.
server_cert_path: str
Path to the server certificate file.
client_cert_path: str
Path to the client certificate file.
"""
if build_info()['USE_FEDERATED']:
_check_call(_LIB.XGBRunFederatedServer(port,
world_size,
c_str(server_key_path),
c_str(server_cert_path),
c_str(client_cert_path)))
else:
raise XGBoostError(
"XGBoost needs to be built with the federated learning plugin "
"enabled in order to use this module"
)

View File

@ -6,7 +6,9 @@ set(RABIT_SOURCES
${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc
${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc)
if (RABIT_BUILD_MPI)
if (PLUGIN_FEDERATED)
# Skip the engine if the Federated Learning plugin is enabled.
elseif (RABIT_BUILD_MPI)
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mpi.cc)
elseif (RABIT_MOCK)
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc)

View File

@ -28,6 +28,10 @@
#include "../data/simple_dmatrix.h"
#include "../data/proxy_dmatrix.h"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_server.h"
#endif
using namespace xgboost; // NOLINT(*);
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
@ -95,6 +99,12 @@ XGB_DLL int XGBuildInfo(char const **out) {
info["DEBUG"] = Boolean{false};
#endif
#if defined(XGBOOST_USE_FEDERATED)
info["USE_FEDERATED"] = Boolean{true};
#else
info["USE_FEDERATED"] = Boolean{false};
#endif
XGBBuildInfoDevice(&info);
auto &out_str = GlobalConfigAPIThreadLocalStore::Get()->ret_str;
@ -198,11 +208,15 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
DMatrixHandle *out) {
API_BEGIN();
bool load_row_split = false;
#if defined(XGBOOST_USE_FEDERATED)
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
#else
if (rabit::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, "
<< "will split data among workers";
load_row_split = true;
}
#endif
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
API_END();
}
@ -1342,5 +1356,14 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
API_END();
}
#if defined(XGBOOST_USE_FEDERATED)
XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path,
char const *server_cert_path, char const *client_cert_path) {
API_BEGIN();
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
API_END();
}
#endif
// force link rabit
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();

View File

@ -18,6 +18,14 @@ if (NOT PLUGIN_UPDATER_ONEAPI)
list(REMOVE_ITEM TEST_SOURCES ${ONEAPI_TEST_SOURCES})
endif (NOT PLUGIN_UPDATER_ONEAPI)
if (PLUGIN_FEDERATED)
target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated)
target_link_libraries(testxgboost PRIVATE federated_client)
else (PLUGIN_FEDERATED)
file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.cc")
list(REMOVE_ITEM TEST_SOURCES ${FEDERATED_TEST_SOURCES})
endif (PLUGIN_FEDERATED)
target_sources(testxgboost PRIVATE ${TEST_SOURCES} ${xgboost_SOURCE_DIR}/plugin/example/custom_obj.cc)
if (USE_CUDA AND PLUGIN_RMM)

View File

@ -0,0 +1,130 @@
/*!
* Copyright 2017-2020 XGBoost contributors
*/
#include <grpcpp/server_builder.h>
#include <gtest/gtest.h>
#include <thread>
#include "federated_client.h"
#include "federated_server.h"
namespace xgboost {
class FederatedServerTest : public ::testing::Test {
public:
static void VerifyAllgather(int rank) {
federated::FederatedClient client{kServerAddress, rank};
CheckAllgather(client, rank);
}
static void VerifyAllreduce(int rank) {
federated::FederatedClient client{kServerAddress, rank};
CheckAllreduce(client);
}
static void VerifyBroadcast(int rank) {
federated::FederatedClient client{kServerAddress, rank};
CheckBroadcast(client, rank);
}
static void VerifyMixture(int rank) {
federated::FederatedClient client{kServerAddress, rank};
for (auto i = 0; i < 10; i++) {
CheckAllgather(client, rank);
CheckAllreduce(client);
CheckBroadcast(client, rank);
}
}
protected:
void SetUp() override {
server_thread_.reset(new std::thread([this] {
grpc::ServerBuilder builder;
federated::FederatedService service{kWorldSize};
builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
server_->Wait();
}));
}
void TearDown() override {
server_->Shutdown();
server_thread_->join();
}
static void CheckAllgather(federated::FederatedClient& client, int rank) {
auto reply = client.Allgather("hello " + std::to_string(rank) + " ");
EXPECT_EQ(reply, "hello 0 hello 1 hello 2 ");
}
static void CheckAllreduce(federated::FederatedClient& client) {
int data[] = {1, 2, 3, 4, 5};
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
auto reply = client.Allreduce(send_buffer, federated::INT, federated::SUM);
auto const* result = reinterpret_cast<int const*>(reply.data());
int expected[] = {3, 6, 9, 12, 15};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(result[i], expected[i]);
}
}
static void CheckBroadcast(federated::FederatedClient& client, int rank) {
std::string send_buffer{};
if (rank == 0) {
send_buffer = "hello broadcast";
}
auto reply = client.Broadcast(send_buffer, 0);
EXPECT_EQ(reply, "hello broadcast");
}
static int const kWorldSize{3};
static std::string const kServerAddress;
std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_;
};
std::string const FederatedServerTest::kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp)
TEST_F(FederatedServerTest, Allgather) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank));
}
for (auto& thread : threads) {
thread.join();
}
}
TEST_F(FederatedServerTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllreduce, rank));
}
for (auto& thread : threads) {
thread.join();
}
}
TEST_F(FederatedServerTest, Broadcast) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&FederatedServerTest::VerifyBroadcast, rank));
}
for (auto& thread : threads) {
thread.join();
}
}
TEST_F(FederatedServerTest, Mixture) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&FederatedServerTest::VerifyMixture, rank));
}
for (auto& thread : threads) {
thread.join();
}
}
} // namespace xgboost

View File

@ -0,0 +1,17 @@
#!/bin/bash
set -e
rm -f ./*.model* ./agaricus* ./*.pem
world_size=3
# Generate server and client certificates.
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"
# Split train and test files manually to simulate a federated environment.
split -n l/${world_size} -d ../../demo/data/agaricus.txt.train agaricus.txt.train-
split -n l/${world_size} -d ../../demo/data/agaricus.txt.test agaricus.txt.test-
python test_federated.py ${world_size}

View File

@ -0,0 +1,78 @@
#!/usr/bin/python
import multiprocessing
import sys
import time
import xgboost as xgb
import xgboost.federated
SERVER_KEY = 'server-key.pem'
SERVER_CERT = 'server-cert.pem'
CLIENT_KEY = 'client-key.pem'
CLIENT_CERT = 'client-cert.pem'
def run_server(port: int, world_size: int) -> None:
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
CLIENT_CERT)
def run_worker(port: int, world_size: int, rank: int) -> None:
# Always call this before using distributed module
rabit_env = [
f'federated_server_address=localhost:{port}',
f'federated_world_size={world_size}',
f'federated_rank={rank}',
f'federated_server_cert={SERVER_CERT}',
f'federated_client_key={CLIENT_KEY}',
f'federated_client_cert={CLIENT_CERT}'
]
xgb.rabit.init([e.encode() for e in rabit_env])
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
# Specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
# Specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 20
# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2)
# Save the model, only ask process 0 to save the model.
if xgb.rabit.get_rank() == 0:
bst.save_model("test.model.json")
xgb.rabit.tracker_print("Finished training\n")
# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
def run_test() -> None:
port = 9091
world_size = int(sys.argv[1])
server = multiprocessing.Process(target=run_server, args=(port, world_size))
server.start()
time.sleep(1)
if not server.is_alive():
raise Exception("Error starting Federated Learning server")
workers = []
for rank in range(world_size):
worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank))
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
server.terminate()
if __name__ == '__main__':
run_test()