Initial support for federated learning (#7831)
Federated learning plugin for xgboost: * A gRPC server to aggregate MPI-style requests (allgather, allreduce, broadcast) from federated workers. * A Rabit engine for the federated environment. * Integration test to simulate federated learning. Additional followups are needed to address GPU support, better security, and privacy, etc.
This commit is contained in:
parent
46e0bce212
commit
14ef38b834
@ -66,6 +66,7 @@ address, leak, undefined and thread.")
|
|||||||
## Plugins
|
## Plugins
|
||||||
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
|
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
|
||||||
option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF)
|
option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF)
|
||||||
|
option(PLUGIN_FEDERATED "Build with Federated Learning" OFF)
|
||||||
## TODO: 1. Add check if DPC++ compiler is used for building
|
## TODO: 1. Add check if DPC++ compiler is used for building
|
||||||
option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF)
|
option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF)
|
||||||
option(ADD_PKGCONFIG "Add xgboost.pc into system." ON)
|
option(ADD_PKGCONFIG "Add xgboost.pc into system." ON)
|
||||||
|
|||||||
@ -40,3 +40,8 @@ if (PLUGIN_UPDATER_ONEAPI)
|
|||||||
# Add all objects of oneapi_plugin to objxgboost
|
# Add all objects of oneapi_plugin to objxgboost
|
||||||
target_sources(objxgboost INTERFACE $<TARGET_OBJECTS:oneapi_plugin>)
|
target_sources(objxgboost INTERFACE $<TARGET_OBJECTS:oneapi_plugin>)
|
||||||
endif (PLUGIN_UPDATER_ONEAPI)
|
endif (PLUGIN_UPDATER_ONEAPI)
|
||||||
|
|
||||||
|
# Add the Federated Learning plugin if enabled.
|
||||||
|
if(PLUGIN_FEDERATED)
  # The subdirectory's CMakeLists.txt attaches the plugin sources to objxgboost.
  add_subdirectory(federated)
endif()
|
||||||
|
|||||||
27
plugin/federated/CMakeLists.txt
Normal file
27
plugin/federated/CMakeLists.txt
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Build rules for the federated learning plugin.
# gRPC needs to be installed first. See README.md.
find_package(Protobuf REQUIRED)
find_package(gRPC REQUIRED)
find_package(Threads)

# Generated code from the protobuf definition.
add_library(federated_proto federated.proto)
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
# Consumers include the generated headers as <federated.pb.h> / <federated.grpc.pb.h>.
target_include_directories(federated_proto PUBLIC "${CMAKE_CURRENT_BINARY_DIR}")
set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON)

protobuf_generate(TARGET federated_proto LANGUAGE cpp)
# Resolve the gRPC code generator at generate time through a generator
# expression rather than reading the LOCATION property at configure time.
protobuf_generate(
  TARGET federated_proto
  LANGUAGE grpc
  GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
  PLUGIN "protoc-gen-grpc=$<TARGET_FILE:gRPC::grpc_cpp_plugin>")

# Wrapper for the gRPC client.
add_library(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)

# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc)
target_link_libraries(objxgboost PRIVATE federated_client)
target_compile_definitions(objxgboost PUBLIC XGBOOST_USE_FEDERATED=1)
|
||||||
35
plugin/federated/README.md
Normal file
35
plugin/federated/README.md
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
XGBoost Plugin for Federated Learning
|
||||||
|
=====================================
|
||||||
|
|
||||||
|
This folder contains the plugin for federated learning. Follow these steps to build and test it.
|
||||||
|
|
||||||
|
Install gRPC
|
||||||
|
------------
|
||||||
|
```shell
|
||||||
|
sudo apt-get install build-essential autoconf libtool pkg-config cmake ninja-build
|
||||||
|
git clone -b v1.45.2 https://github.com/grpc/grpc
|
||||||
|
cd grpc
|
||||||
|
git submodule update --init
|
||||||
|
cmake -S . -B build -GNinja -DABSL_PROPAGATE_CXX_STD=ON
|
||||||
|
cmake --build build --target install
|
||||||
|
```
|
||||||
|
|
||||||
|
Build the Plugin
|
||||||
|
----------------
|
||||||
|
```shell
|
||||||
|
# Under xgboost source tree.
|
||||||
|
mkdir build
|
||||||
|
cd build
|
||||||
|
cmake .. -GNinja -DPLUGIN_FEDERATED=ON
|
||||||
|
ninja
|
||||||
|
cd ../python-package
|
||||||
|
pip install -e . # or equivalently python setup.py develop
|
||||||
|
```
|
||||||
|
|
||||||
|
Test Federated XGBoost
|
||||||
|
----------------------
|
||||||
|
```shell
|
||||||
|
# Under xgboost source tree.
|
||||||
|
cd tests/distributed
|
||||||
|
./runtests-federated.sh
|
||||||
|
```
|
||||||
274
plugin/federated/engine_federated.cc
Normal file
274
plugin/federated/engine_federated.cc
Normal file
@ -0,0 +1,274 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <cstdio>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "federated_client.h"
|
||||||
|
#include "rabit/internal/engine.h"
|
||||||
|
#include "rabit/internal/utils.h"
|
||||||
|
|
||||||
|
namespace MPI {  // NOLINT
/*!
 * \brief Stand-in for the MPI data type so the existing MPI-style interface keeps compiling.
 *        It only carries the size value it was constructed with.
 */
class Datatype {
 public:
  explicit Datatype(size_t type_size) : type_size{type_size} {}

  size_t type_size;
};
}  // namespace MPI
|
||||||
|
|
||||||
|
namespace rabit {
|
||||||
|
namespace engine {
|
||||||
|
|
||||||
|
/*! \brief implementation of engine using federated learning */
|
||||||
|
class FederatedEngine : public IEngine {
|
||||||
|
public:
|
||||||
|
void Init(int argc, char *argv[]) {
|
||||||
|
// Parse environment variables first.
|
||||||
|
for (auto const &env_var : env_vars_) {
|
||||||
|
char const *value = getenv(env_var.c_str());
|
||||||
|
if (value != nullptr) {
|
||||||
|
SetParam(env_var, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Command line argument overrides.
|
||||||
|
for (int i = 0; i < argc; ++i) {
|
||||||
|
std::string const key_value = argv[i];
|
||||||
|
auto const delimiter = key_value.find('=');
|
||||||
|
if (delimiter != std::string::npos) {
|
||||||
|
SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
|
||||||
|
server_address_.c_str(), world_size_, rank_);
|
||||||
|
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
|
||||||
|
client_key_, client_cert_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Finalize() { client_.reset(); }
|
||||||
|
|
||||||
|
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end,
|
||||||
|
size_t size_prev_slice) override {
|
||||||
|
throw std::logic_error("FederatedEngine:: Allgather is not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Allgather(void *sendbuf, size_t total_size) {
|
||||||
|
std::string const send_buffer(reinterpret_cast<char *>(sendbuf), total_size);
|
||||||
|
return client_->Allgather(send_buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer,
|
||||||
|
PreprocFunction prepare_fun, void *prepare_arg) override {
|
||||||
|
throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
|
||||||
|
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
|
||||||
|
std::string const send_buffer(buffer, size);
|
||||||
|
auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op));
|
||||||
|
receive_buffer.copy(buffer, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
int GetRingPrevRank() const override {
|
||||||
|
throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Broadcast(void *sendrecvbuf, size_t size, int root) override {
|
||||||
|
if (world_size_ == 1) return;
|
||||||
|
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
|
||||||
|
std::string const send_buffer(buffer, size);
|
||||||
|
auto const receive_buffer = client_->Broadcast(send_buffer, root);
|
||||||
|
if (rank_ != root) {
|
||||||
|
receive_buffer.copy(buffer, size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int LoadCheckPoint(Serializable *global_model, Serializable *local_model = nullptr) override {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CheckPoint(const Serializable *global_model,
|
||||||
|
const Serializable *local_model = nullptr) override {
|
||||||
|
version_number_ += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LazyCheckPoint(const Serializable *global_model) override { version_number_ += 1; }
|
||||||
|
|
||||||
|
int VersionNumber() const override { return version_number_; }
|
||||||
|
|
||||||
|
/*! \brief get rank of current node */
|
||||||
|
int GetRank() const override { return rank_; }
|
||||||
|
|
||||||
|
/*! \brief get total number of */
|
||||||
|
int GetWorldSize() const override { return world_size_; }
|
||||||
|
|
||||||
|
/*! \brief whether it is distributed */
|
||||||
|
bool IsDistributed() const override { return true; }
|
||||||
|
|
||||||
|
/*! \brief get the host name of current node */
|
||||||
|
std::string GetHost() const override { return "rank" + std::to_string(rank_); }
|
||||||
|
|
||||||
|
void TrackerPrint(const std::string &msg) override {
|
||||||
|
// simply print information into the tracker
|
||||||
|
if (GetRank() == 0) {
|
||||||
|
utils::Printf("%s", msg.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** @brief Transform mpi::DataType to xgboost::federated::DataType. */
|
||||||
|
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) {
|
||||||
|
switch (data_type) {
|
||||||
|
case mpi::kChar:
|
||||||
|
return xgboost::federated::CHAR;
|
||||||
|
case mpi::kUChar:
|
||||||
|
return xgboost::federated::UCHAR;
|
||||||
|
case mpi::kInt:
|
||||||
|
return xgboost::federated::INT;
|
||||||
|
case mpi::kUInt:
|
||||||
|
return xgboost::federated::UINT;
|
||||||
|
case mpi::kLong:
|
||||||
|
return xgboost::federated::LONG;
|
||||||
|
case mpi::kULong:
|
||||||
|
return xgboost::federated::ULONG;
|
||||||
|
case mpi::kFloat:
|
||||||
|
return xgboost::federated::FLOAT;
|
||||||
|
case mpi::kDouble:
|
||||||
|
return xgboost::federated::DOUBLE;
|
||||||
|
case mpi::kLongLong:
|
||||||
|
return xgboost::federated::LONGLONG;
|
||||||
|
case mpi::kULongLong:
|
||||||
|
return xgboost::federated::ULONGLONG;
|
||||||
|
}
|
||||||
|
utils::Error("unknown mpi::DataType");
|
||||||
|
return xgboost::federated::CHAR;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief Transform mpi::OpType to enum to MPI OP */
|
||||||
|
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) {
|
||||||
|
switch (op_type) {
|
||||||
|
case mpi::kMax:
|
||||||
|
return xgboost::federated::MAX;
|
||||||
|
case mpi::kMin:
|
||||||
|
return xgboost::federated::MIN;
|
||||||
|
case mpi::kSum:
|
||||||
|
return xgboost::federated::SUM;
|
||||||
|
case mpi::kBitwiseOR:
|
||||||
|
utils::Error("Bitwise OR is not supported");
|
||||||
|
return xgboost::federated::MAX;
|
||||||
|
}
|
||||||
|
utils::Error("unknown mpi::OpType");
|
||||||
|
return xgboost::federated::MAX;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetParam(std::string const &name, std::string const &val) {
|
||||||
|
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
|
||||||
|
server_address_ = val;
|
||||||
|
} else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) {
|
||||||
|
world_size_ = std::stoi(val);
|
||||||
|
} else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) {
|
||||||
|
rank_ = std::stoi(val);
|
||||||
|
} else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) {
|
||||||
|
server_cert_ = ReadFile(val);
|
||||||
|
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) {
|
||||||
|
client_key_ = ReadFile(val);
|
||||||
|
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) {
|
||||||
|
client_cert_ = ReadFile(val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string ReadFile(std::string const &path) {
|
||||||
|
auto stream = std::ifstream(path.data());
|
||||||
|
std::ostringstream out;
|
||||||
|
out << stream.rdbuf();
|
||||||
|
return out.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
std::vector<std::string> const env_vars_{
|
||||||
|
"FEDERATED_SERVER_ADDRESS",
|
||||||
|
"FEDERATED_WORLD_SIZE",
|
||||||
|
"FEDERATED_RANK",
|
||||||
|
"FEDERATED_SERVER_CERT",
|
||||||
|
"FEDERATED_CLIENT_KEY",
|
||||||
|
"FEDERATED_CLIENT_CERT" };
|
||||||
|
// clang-format on
|
||||||
|
std::string server_address_{"localhost:9091"};
|
||||||
|
int world_size_{1};
|
||||||
|
int rank_{0};
|
||||||
|
std::string server_cert_{};
|
||||||
|
std::string client_key_{};
|
||||||
|
std::string client_cert_{};
|
||||||
|
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
|
||||||
|
int version_number_{0};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Singleton federated engine.
|
||||||
|
FederatedEngine engine; // NOLINT(cert-err58-cpp)
|
||||||
|
|
||||||
|
/*! \brief initialize the synchronization module */
|
||||||
|
bool Init(int argc, char *argv[]) {
|
||||||
|
try {
|
||||||
|
engine.Init(argc, argv);
|
||||||
|
return true;
|
||||||
|
} catch (std::exception const &e) {
|
||||||
|
fprintf(stderr, " failed in federated Init %s\n", e.what());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*! \brief finalize synchronization module */
|
||||||
|
bool Finalize() {
|
||||||
|
try {
|
||||||
|
engine.Finalize();
|
||||||
|
return true;
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
fprintf(stderr, "failed in federated shutdown %s\n", e.what());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*! \brief singleton method to get engine */
|
||||||
|
IEngine *GetEngine() { return &engine; }
|
||||||
|
|
||||||
|
// perform in-place allreduce, on sendrecvbuf
|
||||||
|
void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red,
|
||||||
|
mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun,
|
||||||
|
void *prepare_arg) {
|
||||||
|
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||||
|
if (engine.GetWorldSize() == 1) return;
|
||||||
|
engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
ReduceHandle::ReduceHandle() = default;
|
||||||
|
ReduceHandle::~ReduceHandle() = default;
|
||||||
|
|
||||||
|
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast<int>(dtype.type_size); }
|
||||||
|
|
||||||
|
void ReduceHandle::Init(IEngine::ReduceFunction redfunc,
|
||||||
|
__attribute__((unused)) size_t type_nbytes) {
|
||||||
|
utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice");
|
||||||
|
redfunc_ = redfunc;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReduceHandle::Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count,
|
||||||
|
IEngine::PreprocFunction prepare_fun, void *prepare_arg) {
|
||||||
|
utils::Assert(redfunc_ != nullptr, "must initialize handle to call AllReduce");
|
||||||
|
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||||
|
if (engine.GetWorldSize() == 1) return;
|
||||||
|
|
||||||
|
// Gather all the buffers and call the reduce function locally.
|
||||||
|
auto const buffer_size = type_nbytes * count;
|
||||||
|
auto const gathered = engine.Allgather(sendrecvbuf, buffer_size);
|
||||||
|
auto const *data = gathered.data();
|
||||||
|
for (int i = 0; i < engine.GetWorldSize(); i++) {
|
||||||
|
if (i != engine.GetRank()) {
|
||||||
|
redfunc_(data + buffer_size * i, sendrecvbuf, static_cast<int>(count),
|
||||||
|
MPI::Datatype(type_nbytes));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace engine
|
||||||
|
} // namespace rabit
|
||||||
68
plugin/federated/federated.proto
Normal file
68
plugin/federated/federated.proto
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package xgboost.federated;
|
||||||
|
|
||||||
|
service Federated {
|
||||||
|
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
|
||||||
|
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
|
||||||
|
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum DataType {
|
||||||
|
CHAR = 0;
|
||||||
|
UCHAR = 1;
|
||||||
|
INT = 2;
|
||||||
|
UINT = 3;
|
||||||
|
LONG = 4;
|
||||||
|
ULONG = 5;
|
||||||
|
FLOAT = 6;
|
||||||
|
DOUBLE = 7;
|
||||||
|
LONGLONG = 8;
|
||||||
|
ULONGLONG = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ReduceOperation {
|
||||||
|
MAX = 0;
|
||||||
|
MIN = 1;
|
||||||
|
SUM = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message AllgatherRequest {
|
||||||
|
// An incrementing counter that is unique to each round of operations.
|
||||||
|
uint64 sequence_number = 1;
|
||||||
|
int32 rank = 2;
|
||||||
|
bytes send_buffer = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message AllgatherReply {
|
||||||
|
bytes receive_buffer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message AllreduceRequest {
|
||||||
|
// An incrementing counter that is unique to each round of operations.
|
||||||
|
uint64 sequence_number = 1;
|
||||||
|
int32 rank = 2;
|
||||||
|
bytes send_buffer = 3;
|
||||||
|
DataType data_type = 4;
|
||||||
|
ReduceOperation reduce_operation = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message AllreduceReply {
|
||||||
|
bytes receive_buffer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message BroadcastRequest {
|
||||||
|
// An incrementing counter that is unique to each round of operations.
|
||||||
|
uint64 sequence_number = 1;
|
||||||
|
int32 rank = 2;
|
||||||
|
bytes send_buffer = 3;
|
||||||
|
// The root rank to broadcast from.
|
||||||
|
int32 root = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message BroadcastReply {
|
||||||
|
bytes receive_buffer = 1;
|
||||||
|
}
|
||||||
104
plugin/federated/federated_client.h
Normal file
104
plugin/federated/federated_client.h
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <federated.grpc.pb.h>
|
||||||
|
#include <federated.pb.h>
|
||||||
|
#include <grpcpp/grpcpp.h>
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace federated {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief A wrapper around the gRPC client.
|
||||||
|
*/
|
||||||
|
class FederatedClient {
|
||||||
|
public:
|
||||||
|
FederatedClient(std::string const &server_address, int rank, std::string const &server_cert,
|
||||||
|
std::string const &client_key, std::string const &client_cert)
|
||||||
|
: stub_{[&] {
|
||||||
|
grpc::SslCredentialsOptions options;
|
||||||
|
options.pem_root_certs = server_cert;
|
||||||
|
options.pem_private_key = client_key;
|
||||||
|
options.pem_cert_chain = client_cert;
|
||||||
|
return Federated::NewStub(
|
||||||
|
grpc::CreateChannel(server_address, grpc::SslCredentials(options)));
|
||||||
|
}()},
|
||||||
|
rank_{rank} {}
|
||||||
|
|
||||||
|
/** @brief Insecure client for testing only. */
|
||||||
|
FederatedClient(std::string const &server_address, int rank)
|
||||||
|
: stub_{Federated::NewStub(
|
||||||
|
grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))},
|
||||||
|
rank_{rank} {}
|
||||||
|
|
||||||
|
std::string Allgather(std::string const &send_buffer) {
|
||||||
|
AllgatherRequest request;
|
||||||
|
request.set_sequence_number(sequence_number_++);
|
||||||
|
request.set_rank(rank_);
|
||||||
|
request.set_send_buffer(send_buffer);
|
||||||
|
|
||||||
|
AllgatherReply reply;
|
||||||
|
grpc::ClientContext context;
|
||||||
|
grpc::Status status = stub_->Allgather(&context, request, &reply);
|
||||||
|
|
||||||
|
if (status.ok()) {
|
||||||
|
return reply.receive_buffer();
|
||||||
|
} else {
|
||||||
|
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||||
|
throw std::runtime_error("Allgather RPC failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Allreduce(std::string const &send_buffer, DataType data_type,
|
||||||
|
ReduceOperation reduce_operation) {
|
||||||
|
AllreduceRequest request;
|
||||||
|
request.set_sequence_number(sequence_number_++);
|
||||||
|
request.set_rank(rank_);
|
||||||
|
request.set_send_buffer(send_buffer);
|
||||||
|
request.set_data_type(data_type);
|
||||||
|
request.set_reduce_operation(reduce_operation);
|
||||||
|
|
||||||
|
AllreduceReply reply;
|
||||||
|
grpc::ClientContext context;
|
||||||
|
grpc::Status status = stub_->Allreduce(&context, request, &reply);
|
||||||
|
|
||||||
|
if (status.ok()) {
|
||||||
|
return reply.receive_buffer();
|
||||||
|
} else {
|
||||||
|
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||||
|
throw std::runtime_error("Allreduce RPC failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Broadcast(std::string const &send_buffer, int root) {
|
||||||
|
BroadcastRequest request;
|
||||||
|
request.set_sequence_number(sequence_number_++);
|
||||||
|
request.set_rank(rank_);
|
||||||
|
request.set_send_buffer(send_buffer);
|
||||||
|
request.set_root(root);
|
||||||
|
|
||||||
|
BroadcastReply reply;
|
||||||
|
grpc::ClientContext context;
|
||||||
|
grpc::Status status = stub_->Broadcast(&context, request, &reply);
|
||||||
|
|
||||||
|
if (status.ok()) {
|
||||||
|
return reply.receive_buffer();
|
||||||
|
} else {
|
||||||
|
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||||
|
throw std::runtime_error("Broadcast RPC failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<Federated::Stub> const stub_;
|
||||||
|
int const rank_;
|
||||||
|
uint64_t sequence_number_{};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace federated
|
||||||
|
} // namespace xgboost
|
||||||
234
plugin/federated/federated_server.cc
Normal file
234
plugin/federated/federated_server.cc
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include "federated_server.h"
|
||||||
|
|
||||||
|
#include <grpcpp/grpcpp.h>
|
||||||
|
#include <grpcpp/server_builder.h>
|
||||||
|
#include <xgboost/logging.h>
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace federated {
|
||||||
|
|
||||||
|
class AllgatherFunctor {
|
||||||
|
public:
|
||||||
|
std::string const name{"Allgather"};
|
||||||
|
|
||||||
|
explicit AllgatherFunctor(int const world_size) : world_size_{world_size} {}
|
||||||
|
|
||||||
|
void operator()(AllgatherRequest const* request, std::string& buffer) const {
|
||||||
|
auto const rank = request->rank();
|
||||||
|
auto const& send_buffer = request->send_buffer();
|
||||||
|
auto const send_size = send_buffer.size();
|
||||||
|
// Resize the buffer if this is the first request.
|
||||||
|
if (buffer.size() != send_size * world_size_) {
|
||||||
|
buffer.resize(send_size * world_size_);
|
||||||
|
}
|
||||||
|
// Splice the send_buffer into the common buffer.
|
||||||
|
buffer.replace(rank * send_size, send_size, send_buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int const world_size_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class AllreduceFunctor {
|
||||||
|
public:
|
||||||
|
std::string const name{"Allreduce"};
|
||||||
|
|
||||||
|
void operator()(AllreduceRequest const* request, std::string& buffer) const {
|
||||||
|
if (buffer.empty()) {
|
||||||
|
// Copy the send_buffer if this is the first request.
|
||||||
|
buffer = request->send_buffer();
|
||||||
|
} else {
|
||||||
|
// Apply the reduce_operation to the send_buffer and the common buffer.
|
||||||
|
Accumulate(buffer, request->send_buffer(), request->data_type(), request->reduce_operation());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <class T>
|
||||||
|
void Accumulate(T* buffer, T const* input, std::size_t n,
|
||||||
|
ReduceOperation reduce_operation) const {
|
||||||
|
switch (reduce_operation) {
|
||||||
|
case ReduceOperation::MAX:
|
||||||
|
std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::max(a, b); });
|
||||||
|
break;
|
||||||
|
case ReduceOperation::MIN:
|
||||||
|
std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::min(a, b); });
|
||||||
|
break;
|
||||||
|
case ReduceOperation::SUM:
|
||||||
|
std::transform(buffer, buffer + n, input, buffer, std::plus<T>());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Invalid reduce operation");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Accumulate(std::string& buffer, std::string const& input, DataType data_type,
|
||||||
|
ReduceOperation reduce_operation) const {
|
||||||
|
switch (data_type) {
|
||||||
|
case DataType::CHAR:
|
||||||
|
Accumulate(&buffer[0], reinterpret_cast<char const*>(input.data()), buffer.size(),
|
||||||
|
reduce_operation);
|
||||||
|
break;
|
||||||
|
case DataType::UCHAR:
|
||||||
|
Accumulate(reinterpret_cast<unsigned char*>(&buffer[0]),
|
||||||
|
reinterpret_cast<unsigned char const*>(input.data()), buffer.size(),
|
||||||
|
reduce_operation);
|
||||||
|
break;
|
||||||
|
case DataType::INT:
|
||||||
|
Accumulate(reinterpret_cast<int*>(&buffer[0]), reinterpret_cast<int const*>(input.data()),
|
||||||
|
buffer.size() / sizeof(int), reduce_operation);
|
||||||
|
break;
|
||||||
|
case DataType::UINT:
|
||||||
|
Accumulate(reinterpret_cast<unsigned int*>(&buffer[0]),
|
||||||
|
reinterpret_cast<unsigned int const*>(input.data()),
|
||||||
|
buffer.size() / sizeof(unsigned int), reduce_operation);
|
||||||
|
break;
|
||||||
|
case DataType::LONG:
|
||||||
|
Accumulate(reinterpret_cast<long*>(&buffer[0]), reinterpret_cast<long const*>(input.data()),
|
||||||
|
buffer.size() / sizeof(long), reduce_operation);
|
||||||
|
break;
|
||||||
|
case DataType::ULONG:
|
||||||
|
Accumulate(reinterpret_cast<unsigned long*>(&buffer[0]),
|
||||||
|
reinterpret_cast<unsigned long const*>(input.data()),
|
||||||
|
buffer.size() / sizeof(unsigned long), reduce_operation);
|
||||||
|
break;
|
||||||
|
case DataType::FLOAT:
|
||||||
|
Accumulate(reinterpret_cast<float*>(&buffer[0]),
|
||||||
|
reinterpret_cast<float const*>(input.data()), buffer.size() / sizeof(float),
|
||||||
|
reduce_operation);
|
||||||
|
break;
|
||||||
|
case DataType::DOUBLE:
|
||||||
|
Accumulate(reinterpret_cast<double*>(&buffer[0]),
|
||||||
|
reinterpret_cast<double const*>(input.data()), buffer.size() / sizeof(double),
|
||||||
|
reduce_operation);
|
||||||
|
break;
|
||||||
|
case DataType::LONGLONG:
|
||||||
|
Accumulate(reinterpret_cast<long long*>(&buffer[0]),
|
||||||
|
reinterpret_cast<long long const*>(input.data()),
|
||||||
|
buffer.size() / sizeof(long long), reduce_operation);
|
||||||
|
break;
|
||||||
|
case DataType::ULONGLONG:
|
||||||
|
Accumulate(reinterpret_cast<unsigned long long*>(&buffer[0]),
|
||||||
|
reinterpret_cast<unsigned long long const*>(input.data()),
|
||||||
|
buffer.size() / sizeof(unsigned long long), reduce_operation);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Invalid data type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class BroadcastFunctor {
|
||||||
|
public:
|
||||||
|
std::string const name{"Broadcast"};
|
||||||
|
|
||||||
|
void operator()(BroadcastRequest const* request, std::string& buffer) const {
|
||||||
|
if (request->rank() == request->root()) {
|
||||||
|
// Copy the send_buffer if this is the root.
|
||||||
|
buffer = request->send_buffer();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/** @brief Allgather RPC: delegates to Handle() with an AllgatherFunctor. */
grpc::Status FederatedService::Allgather(grpc::ServerContext* /*context*/,
                                         AllgatherRequest const* request, AllgatherReply* reply) {
  AllgatherFunctor const gather{world_size_};
  return Handle(request, reply, gather);
}
|
||||||
|
|
||||||
|
/** @brief Allreduce RPC: delegates to Handle() with an AllreduceFunctor. */
grpc::Status FederatedService::Allreduce(grpc::ServerContext* /*context*/,
                                         AllreduceRequest const* request, AllreduceReply* reply) {
  AllreduceFunctor const reduce{};
  return Handle(request, reply, reduce);
}
|
||||||
|
|
||||||
|
/** @brief Broadcast RPC: delegates to Handle() with a BroadcastFunctor. */
grpc::Status FederatedService::Broadcast(grpc::ServerContext* /*context*/,
                                         BroadcastRequest const* request, BroadcastReply* reply) {
  BroadcastFunctor const broadcast{};
  return Handle(request, reply, broadcast);
}
|
||||||
|
|
||||||
|
template <class Request, class Reply, class RequestFunctor>
|
||||||
|
grpc::Status FederatedService::Handle(Request const* request, Reply* reply,
|
||||||
|
RequestFunctor const& functor) {
|
||||||
|
// Pass through if there is only 1 client.
|
||||||
|
if (world_size_ == 1) {
|
||||||
|
reply->set_receive_buffer(request->send_buffer());
|
||||||
|
return grpc::Status::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
|
||||||
|
auto const sequence_number = request->sequence_number();
|
||||||
|
auto const rank = request->rank();
|
||||||
|
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number";
|
||||||
|
cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; });
|
||||||
|
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": handling request";
|
||||||
|
functor(request, buffer_);
|
||||||
|
received_++;
|
||||||
|
|
||||||
|
if (received_ == world_size_) {
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": all requests received";
|
||||||
|
reply->set_receive_buffer(buffer_);
|
||||||
|
sent_++;
|
||||||
|
lock.unlock();
|
||||||
|
cv_.notify_all();
|
||||||
|
return grpc::Status::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients";
|
||||||
|
cv_.wait(lock, [this] { return received_ == world_size_; });
|
||||||
|
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": sending reply";
|
||||||
|
reply->set_receive_buffer(buffer_);
|
||||||
|
sent_++;
|
||||||
|
|
||||||
|
if (sent_ == world_size_) {
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": all replies sent";
|
||||||
|
sent_ = 0;
|
||||||
|
received_ = 0;
|
||||||
|
buffer_.clear();
|
||||||
|
sequence_number_++;
|
||||||
|
lock.unlock();
|
||||||
|
cv_.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
return grpc::Status::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Read a whole file into a string.
 * @param path path of the file to read
 * @return the file contents
 * @throws std::runtime_error if the file cannot be opened, instead of silently
 *         returning an empty string (which would yield empty credentials)
 */
std::string ReadFile(char const* path) {
  std::ifstream stream(path);
  if (!stream.is_open()) {
    throw std::runtime_error(std::string("Cannot open file ") + path);
  }
  std::ostringstream out;
  out << stream.rdbuf();
  return out.str();
}
|
||||||
|
|
||||||
|
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
||||||
|
char const* client_cert_file) {
|
||||||
|
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||||
|
FederatedService service{world_size};
|
||||||
|
|
||||||
|
grpc::ServerBuilder builder;
|
||||||
|
auto options =
|
||||||
|
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
|
||||||
|
options.pem_root_certs = ReadFile(client_cert_file);
|
||||||
|
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
|
||||||
|
key.private_key = ReadFile(server_key_file);
|
||||||
|
key.cert_chain = ReadFile(server_cert_file);
|
||||||
|
options.pem_key_cert_pairs.push_back(key);
|
||||||
|
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
|
||||||
|
builder.RegisterService(&service);
|
||||||
|
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
||||||
|
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
|
||||||
|
<< world_size;
|
||||||
|
|
||||||
|
server->Wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace federated
|
||||||
|
} // namespace xgboost
|
||||||
44
plugin/federated/federated_server.h
Normal file
44
plugin/federated/federated_server.h
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <federated.grpc.pb.h>
|
||||||
|
|
||||||
|
#include <condition_variable>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace federated {
|
||||||
|
|
||||||
|
class FederatedService final : public Federated::Service {
|
||||||
|
public:
|
||||||
|
explicit FederatedService(int const world_size) : world_size_{world_size} {}
|
||||||
|
|
||||||
|
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||||
|
AllgatherReply* reply) override;
|
||||||
|
|
||||||
|
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
|
||||||
|
AllreduceReply* reply) override;
|
||||||
|
|
||||||
|
grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
|
||||||
|
BroadcastReply* reply) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <class Request, class Reply, class RequestFunctor>
|
||||||
|
grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor);
|
||||||
|
|
||||||
|
int const world_size_;
|
||||||
|
int received_{};
|
||||||
|
int sent_{};
|
||||||
|
std::string buffer_{};
|
||||||
|
uint64_t sequence_number_{};
|
||||||
|
mutable std::mutex mutex_;
|
||||||
|
mutable std::condition_variable cv_;
|
||||||
|
};
|
||||||
|
|
||||||
|
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
||||||
|
char const* client_cert_file);
|
||||||
|
|
||||||
|
} // namespace federated
|
||||||
|
} // namespace xgboost
|
||||||
36
python-package/xgboost/federated.py
Normal file
36
python-package/xgboost/federated.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
"""XGBoost Federated Learning related API."""
|
||||||
|
|
||||||
|
from .core import _LIB, _check_call, c_str, build_info, XGBoostError
|
||||||
|
|
||||||
|
|
||||||
|
def run_federated_server(port: int,
                         world_size: int,
                         server_key_path: str,
                         server_cert_path: str,
                         client_cert_path: str) -> None:
    """Launch the Federated Learning server through the native library.

    Parameters
    ----------
    port : int
        Port the server listens on.
    world_size : int
        Number of federated workers taking part.
    server_key_path : str
        Location of the server private key file.
    server_cert_path : str
        Location of the server certificate file.
    client_cert_path : str
        Location of the client certificate file.

    Raises
    ------
    XGBoostError
        If the library was built without the federated learning plugin.
    """
    # Bail out early when the native library lacks the plugin.
    if not build_info()['USE_FEDERATED']:
        raise XGBoostError(
            "XGBoost needs to be built with the federated learning plugin "
            "enabled in order to use this module"
        )

    # Convert the paths to C strings in the same order they are passed on.
    key_path = c_str(server_key_path)
    cert_path = c_str(server_cert_path)
    client_path = c_str(client_cert_path)
    status = _LIB.XGBRunFederatedServer(port, world_size, key_path, cert_path, client_path)
    _check_call(status)
|
||||||
@ -6,7 +6,9 @@ set(RABIT_SOURCES
|
|||||||
${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc
|
${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc
|
||||||
${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc)
|
${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc)
|
||||||
|
|
||||||
if (RABIT_BUILD_MPI)
|
if (PLUGIN_FEDERATED)
|
||||||
|
# Skip the engine if the Federated Learning plugin is enabled.
|
||||||
|
elseif (RABIT_BUILD_MPI)
|
||||||
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mpi.cc)
|
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mpi.cc)
|
||||||
elseif (RABIT_MOCK)
|
elseif (RABIT_MOCK)
|
||||||
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc)
|
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc)
|
||||||
|
|||||||
@ -28,6 +28,10 @@
|
|||||||
#include "../data/simple_dmatrix.h"
|
#include "../data/simple_dmatrix.h"
|
||||||
#include "../data/proxy_dmatrix.h"
|
#include "../data/proxy_dmatrix.h"
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_FEDERATED)
|
||||||
|
#include "../../plugin/federated/federated_server.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
using namespace xgboost; // NOLINT(*);
|
using namespace xgboost; // NOLINT(*);
|
||||||
|
|
||||||
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
|
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
|
||||||
@ -95,6 +99,12 @@ XGB_DLL int XGBuildInfo(char const **out) {
|
|||||||
info["DEBUG"] = Boolean{false};
|
info["DEBUG"] = Boolean{false};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_FEDERATED)
|
||||||
|
info["USE_FEDERATED"] = Boolean{true};
|
||||||
|
#else
|
||||||
|
info["USE_FEDERATED"] = Boolean{false};
|
||||||
|
#endif
|
||||||
|
|
||||||
XGBBuildInfoDevice(&info);
|
XGBBuildInfoDevice(&info);
|
||||||
|
|
||||||
auto &out_str = GlobalConfigAPIThreadLocalStore::Get()->ret_str;
|
auto &out_str = GlobalConfigAPIThreadLocalStore::Get()->ret_str;
|
||||||
@ -198,11 +208,15 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
|
|||||||
DMatrixHandle *out) {
|
DMatrixHandle *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
bool load_row_split = false;
|
bool load_row_split = false;
|
||||||
|
#if defined(XGBOOST_USE_FEDERATED)
|
||||||
|
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
|
||||||
|
#else
|
||||||
if (rabit::IsDistributed()) {
|
if (rabit::IsDistributed()) {
|
||||||
LOG(CONSOLE) << "XGBoost distributed mode detected, "
|
LOG(CONSOLE) << "XGBoost distributed mode detected, "
|
||||||
<< "will split data among workers";
|
<< "will split data among workers";
|
||||||
load_row_split = true;
|
load_row_split = true;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
|
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -1342,5 +1356,14 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_FEDERATED)
|
||||||
|
/*!
 * \brief C API entry point that runs the federated learning server.
 *
 * Forwards all arguments unchanged to xgboost::federated::RunServer.
 *
 * \param port             Port the server listens on.
 * \param world_size       Number of federated workers.
 * \param server_key_path  Path to the server private key file.
 * \param server_cert_path Path to the server certificate file.
 * \param client_cert_path Path to the client certificate file.
 * \return Status code produced by the API_BEGIN/API_END wrapper.
 */
XGB_DLL int XGBRunFederatedServer(int port,
                                  int world_size,
                                  char const *server_key_path,
                                  char const *server_cert_path,
                                  char const *client_cert_path) {
  API_BEGIN();
  xgboost::federated::RunServer(port,
                                world_size,
                                server_key_path,
                                server_cert_path,
                                client_cert_path);
  API_END();
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// force link rabit
|
// force link rabit
|
||||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||||
|
|||||||
@ -18,6 +18,14 @@ if (NOT PLUGIN_UPDATER_ONEAPI)
|
|||||||
list(REMOVE_ITEM TEST_SOURCES ${ONEAPI_TEST_SOURCES})
|
list(REMOVE_ITEM TEST_SOURCES ${ONEAPI_TEST_SOURCES})
|
||||||
endif (NOT PLUGIN_UPDATER_ONEAPI)
|
endif (NOT PLUGIN_UPDATER_ONEAPI)
|
||||||
|
|
||||||
|
# Federated plugin tests: without the plugin their sources are dropped from the
# test list; with it, the test binary gets the plugin headers and client library.
if (NOT PLUGIN_FEDERATED)
  file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.cc")
  list(REMOVE_ITEM TEST_SOURCES ${FEDERATED_TEST_SOURCES})
else (NOT PLUGIN_FEDERATED)
  target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated)
  target_link_libraries(testxgboost PRIVATE federated_client)
endif (NOT PLUGIN_FEDERATED)
|
||||||
|
|
||||||
target_sources(testxgboost PRIVATE ${TEST_SOURCES} ${xgboost_SOURCE_DIR}/plugin/example/custom_obj.cc)
|
target_sources(testxgboost PRIVATE ${TEST_SOURCES} ${xgboost_SOURCE_DIR}/plugin/example/custom_obj.cc)
|
||||||
|
|
||||||
if (USE_CUDA AND PLUGIN_RMM)
|
if (USE_CUDA AND PLUGIN_RMM)
|
||||||
|
|||||||
130
tests/cpp/plugin/test_federated_server.cc
Normal file
130
tests/cpp/plugin/test_federated_server.cc
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2017-2020 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <grpcpp/server_builder.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "federated_client.h"
|
||||||
|
#include "federated_server.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
|
||||||
|
class FederatedServerTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
static void VerifyAllgather(int rank) {
|
||||||
|
federated::FederatedClient client{kServerAddress, rank};
|
||||||
|
CheckAllgather(client, rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void VerifyAllreduce(int rank) {
|
||||||
|
federated::FederatedClient client{kServerAddress, rank};
|
||||||
|
CheckAllreduce(client);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void VerifyBroadcast(int rank) {
|
||||||
|
federated::FederatedClient client{kServerAddress, rank};
|
||||||
|
CheckBroadcast(client, rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void VerifyMixture(int rank) {
|
||||||
|
federated::FederatedClient client{kServerAddress, rank};
|
||||||
|
for (auto i = 0; i < 10; i++) {
|
||||||
|
CheckAllgather(client, rank);
|
||||||
|
CheckAllreduce(client);
|
||||||
|
CheckBroadcast(client, rank);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
server_thread_.reset(new std::thread([this] {
|
||||||
|
grpc::ServerBuilder builder;
|
||||||
|
federated::FederatedService service{kWorldSize};
|
||||||
|
builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials());
|
||||||
|
builder.RegisterService(&service);
|
||||||
|
server_ = builder.BuildAndStart();
|
||||||
|
server_->Wait();
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TearDown() override {
|
||||||
|
server_->Shutdown();
|
||||||
|
server_thread_->join();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void CheckAllgather(federated::FederatedClient& client, int rank) {
|
||||||
|
auto reply = client.Allgather("hello " + std::to_string(rank) + " ");
|
||||||
|
EXPECT_EQ(reply, "hello 0 hello 1 hello 2 ");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void CheckAllreduce(federated::FederatedClient& client) {
|
||||||
|
int data[] = {1, 2, 3, 4, 5};
|
||||||
|
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
|
||||||
|
auto reply = client.Allreduce(send_buffer, federated::INT, federated::SUM);
|
||||||
|
auto const* result = reinterpret_cast<int const*>(reply.data());
|
||||||
|
int expected[] = {3, 6, 9, 12, 15};
|
||||||
|
for (auto i = 0; i < 5; i++) {
|
||||||
|
EXPECT_EQ(result[i], expected[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void CheckBroadcast(federated::FederatedClient& client, int rank) {
|
||||||
|
std::string send_buffer{};
|
||||||
|
if (rank == 0) {
|
||||||
|
send_buffer = "hello broadcast";
|
||||||
|
}
|
||||||
|
auto reply = client.Broadcast(send_buffer, 0);
|
||||||
|
EXPECT_EQ(reply, "hello broadcast");
|
||||||
|
}
|
||||||
|
|
||||||
|
static int const kWorldSize{3};
|
||||||
|
static std::string const kServerAddress;
|
||||||
|
std::unique_ptr<std::thread> server_thread_;
|
||||||
|
std::unique_ptr<grpc::Server> server_;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string const FederatedServerTest::kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp)
|
||||||
|
|
||||||
|
// Every rank performs one allgather from its own thread.
TEST_F(FederatedServerTest, Allgather) {
  std::vector<std::thread> workers;
  for (int rank = 0; rank < kWorldSize; ++rank) {
    workers.emplace_back(&FederatedServerTest::VerifyAllgather, rank);
  }
  for (std::thread& worker : workers) {
    worker.join();
  }
}
|
||||||
|
|
||||||
|
// Every rank performs one allreduce from its own thread.
TEST_F(FederatedServerTest, Allreduce) {
  std::vector<std::thread> workers;
  for (int rank = 0; rank < kWorldSize; ++rank) {
    workers.emplace_back(&FederatedServerTest::VerifyAllreduce, rank);
  }
  for (std::thread& worker : workers) {
    worker.join();
  }
}
|
||||||
|
|
||||||
|
// Every rank performs one broadcast from its own thread.
TEST_F(FederatedServerTest, Broadcast) {
  std::vector<std::thread> workers;
  for (int rank = 0; rank < kWorldSize; ++rank) {
    workers.emplace_back(&FederatedServerTest::VerifyBroadcast, rank);
  }
  for (std::thread& worker : workers) {
    worker.join();
  }
}
|
||||||
|
|
||||||
|
// Every rank runs the mixed sequence of collective calls from its own thread.
TEST_F(FederatedServerTest, Mixture) {
  std::vector<std::thread> workers;
  for (int rank = 0; rank < kWorldSize; ++rank) {
    workers.emplace_back(&FederatedServerTest::VerifyMixture, rank);
  }
  for (std::thread& worker : workers) {
    worker.join();
  }
}
|
||||||
|
|
||||||
|
} // namespace xgboost
|
||||||
17
tests/distributed/runtests-federated.sh
Executable file
17
tests/distributed/runtests-federated.sh
Executable file
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
# Integration test for federated learning: generates throw-away certificates,
# splits the demo data into one shard per worker and runs test_federated.py.

set -euo pipefail

# All paths below are relative to this script's directory. Switch to it first so
# that the cleanup never deletes files in whatever directory the caller is in.
cd "$(dirname "${BASH_SOURCE[0]}")"

rm -f ./*.model* ./agaricus* ./*.pem

world_size=3

# Generate server and client certificates.
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"

# Split train and test files manually to simulate a federated environment.
split -n "l/${world_size}" -d ../../demo/data/agaricus.txt.train agaricus.txt.train-
split -n "l/${world_size}" -d ../../demo/data/agaricus.txt.test agaricus.txt.test-

python test_federated.py "${world_size}"
|
||||||
78
tests/distributed/test_federated.py
Normal file
78
tests/distributed/test_federated.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
#!/usr/bin/python
|
||||||
|
import multiprocessing
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
import xgboost.federated
|
||||||
|
|
||||||
|
SERVER_KEY = 'server-key.pem'
|
||||||
|
SERVER_CERT = 'server-cert.pem'
|
||||||
|
CLIENT_KEY = 'client-key.pem'
|
||||||
|
CLIENT_CERT = 'client-cert.pem'
|
||||||
|
|
||||||
|
|
||||||
|
def run_server(port: int, world_size: int) -> None:
    """Start the federated server with the test certificates; used as a process target."""
    xgboost.federated.run_federated_server(
        port,
        world_size,
        server_key_path=SERVER_KEY,
        server_cert_path=SERVER_CERT,
        client_cert_path=CLIENT_CERT,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def run_worker(port: int, world_size: int, rank: int) -> None:
    """Train a model as federated worker `rank` against the server on `port`."""
    # Rabit must be initialised before any distributed functionality is used.
    settings = {
        'federated_server_address': f'localhost:{port}',
        'federated_world_size': world_size,
        'federated_rank': rank,
        'federated_server_cert': SERVER_CERT,
        'federated_client_key': CLIENT_KEY,
        'federated_client_cert': CLIENT_CERT,
    }
    xgb.rabit.init([f'{name}={value}'.encode() for name, value in settings.items()])

    # Each worker loads its own pre-split shard; files are not sharded in federated mode.
    dtrain = xgb.DMatrix(f'agaricus.txt.train-{rank:02d}')
    dtest = xgb.DMatrix(f'agaricus.txt.test-{rank:02d}')

    # Training parameters, named the same as in the C++ implementation.
    param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}

    # Data sets evaluated after every boosting round.
    watchlist = [(dtest, 'eval'), (dtrain, 'train')]
    num_round = 20

    bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2)

    # Only the worker with rank 0 writes the model to disk.
    if xgb.rabit.get_rank() == 0:
        bst.save_model("test.model.json")
        xgb.rabit.tracker_print("Finished training\n")

    # Shut rabit down once training is complete.
    xgb.rabit.finalize()
|
||||||
|
|
||||||
|
|
||||||
|
def run_test() -> None:
    """Run one federated server and ``world_size`` workers, failing if any worker fails.

    The world size is read from the first command line argument.

    Raises
    ------
    Exception
        If the server process is no longer alive one second after being started.
    RuntimeError
        If any worker process ends with a non-zero exit code.
    """
    port = 9091
    world_size = int(sys.argv[1])

    server = multiprocessing.Process(target=run_server, args=(port, world_size))
    server.start()
    workers = []
    try:
        time.sleep(1)
        if not server.is_alive():
            raise Exception("Error starting Federated Learning server")

        for rank in range(world_size):
            worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank))
            worker.start()
            workers.append(worker)

        # Poll instead of joining one by one: when a worker dies, its peers may wait
        # on the server for a request that never arrives, so stop them right away.
        pending = list(workers)
        while pending:
            for worker in list(pending):
                worker.join(timeout=0.1)
                if worker.exitcode is None:
                    continue
                pending.remove(worker)
                if worker.exitcode != 0:
                    for other in pending:
                        other.terminate()
    finally:
        # Never leave child processes behind, even when something above raised.
        for worker in workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()
        server.terminate()
        server.join()

    # A child's exception does not propagate to the parent; inspect the exit codes.
    failed = [rank for rank, worker in enumerate(workers) if worker.exitcode != 0]
    if failed:
        raise RuntimeError(f"Federated worker(s) with rank {failed} exited with an error")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run_test()
|
||||||
Loading…
x
Reference in New Issue
Block a user