[coll] Federated comm. (#9732)
This commit is contained in:
@@ -28,6 +28,6 @@ target_sources(federated_client INTERFACE federated_client.h)
|
||||
target_link_libraries(federated_client INTERFACE federated_proto)
|
||||
|
||||
# Rabit engine for Federated Learning.
|
||||
target_sources(objxgboost PRIVATE federated_server.cc)
|
||||
target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
|
||||
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
|
||||
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
|
||||
|
||||
114
plugin/federated/federated_comm.cc
Normal file
114
plugin/federated/federated_comm.cc
Normal file
@@ -0,0 +1,114 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
*/
|
||||
#include "federated_comm.h"
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <cstdlib> // for getenv
|
||||
#include <string> // for string, stoi
|
||||
|
||||
#include "../../src/common/common.h" // for Split
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h" // for Json
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_t world,
|
||||
std::int32_t rank, std::string const& server_cert,
|
||||
std::string const& client_key, std::string const& client_cert) {
|
||||
this->rank_ = rank;
|
||||
this->world_ = world;
|
||||
|
||||
this->tracker_.host = host;
|
||||
this->tracker_.port = port;
|
||||
this->tracker_.rank = rank;
|
||||
|
||||
CHECK_GE(world, 1) << "Invalid world size.";
|
||||
CHECK_GE(rank, 0) << "Invalid worker rank.";
|
||||
CHECK_LT(rank, world) << "Invalid worker rank.";
|
||||
|
||||
if (server_cert.empty()) {
|
||||
stub_ = [&] {
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
return federated::Federated::NewStub(
|
||||
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
|
||||
}();
|
||||
} else {
|
||||
stub_ = [&] {
|
||||
grpc::SslCredentialsOptions options;
|
||||
options.pem_root_certs = server_cert;
|
||||
options.pem_private_key = client_key;
|
||||
options.pem_cert_chain = client_cert;
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
|
||||
channel->WaitForConnected(
|
||||
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
|
||||
return federated::Federated::NewStub(channel);
|
||||
}();
|
||||
}
|
||||
}
|
||||
|
||||
FederatedComm::FederatedComm(Json const& config) {
|
||||
/**
|
||||
* Topology
|
||||
*/
|
||||
std::string server_address{};
|
||||
std::int32_t world_size{0};
|
||||
std::int32_t rank{-1};
|
||||
// Parse environment variables first.
|
||||
auto* value = std::getenv("FEDERATED_SERVER_ADDRESS");
|
||||
if (value != nullptr) {
|
||||
server_address = value;
|
||||
}
|
||||
value = std::getenv("FEDERATED_WORLD_SIZE");
|
||||
if (value != nullptr) {
|
||||
world_size = std::stoi(value);
|
||||
}
|
||||
value = std::getenv("FEDERATED_RANK");
|
||||
if (value != nullptr) {
|
||||
rank = std::stoi(value);
|
||||
}
|
||||
|
||||
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
|
||||
world_size =
|
||||
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
|
||||
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
|
||||
|
||||
auto parsed = common::Split(server_address, ':');
|
||||
CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address;
|
||||
|
||||
CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
|
||||
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
|
||||
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
|
||||
|
||||
/**
|
||||
* Certificates
|
||||
*/
|
||||
std::string server_cert{};
|
||||
std::string client_key{};
|
||||
std::string client_cert{};
|
||||
value = getenv("FEDERATED_SERVER_CERT_PATH");
|
||||
if (value != nullptr) {
|
||||
server_cert = value;
|
||||
}
|
||||
value = getenv("FEDERATED_CLIENT_KEY_PATH");
|
||||
if (value != nullptr) {
|
||||
client_key = value;
|
||||
}
|
||||
value = getenv("FEDERATED_CLIENT_CERT_PATH");
|
||||
if (value != nullptr) {
|
||||
client_cert = value;
|
||||
}
|
||||
|
||||
server_cert = OptionalArg<String>(config, "federated_server_cert_path", server_cert);
|
||||
client_key = OptionalArg<String>(config, "federated_client_key_path", client_key);
|
||||
client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert);
|
||||
|
||||
this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
|
||||
client_cert);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
53
plugin/federated/federated_comm.h
Normal file
53
plugin/federated/federated_comm.h
Normal file
@@ -0,0 +1,53 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <federated.grpc.pb.h>
|
||||
#include <federated.pb.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../src/collective/comm.h" // for Comm
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedComm : public Comm {
|
||||
std::unique_ptr<federated::Federated::Stub> stub_;
|
||||
|
||||
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
|
||||
std::string const& server_cert, std::string const& client_key,
|
||||
std::string const& client_cert);
|
||||
|
||||
public:
|
||||
/**
|
||||
* @param config
|
||||
*
|
||||
* - federated_server_address: Tracker address
|
||||
* - federated_world_size: The number of workers
|
||||
* - federated_rank: Rank of federated worker
|
||||
* - federated_server_cert_path
|
||||
* - federated_client_key_path
|
||||
* - federated_client_cert_path
|
||||
*/
|
||||
explicit FederatedComm(Json const& config);
|
||||
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
|
||||
std::int32_t rank) {
|
||||
this->Init(host, port, world, rank, {}, {}, {});
|
||||
}
|
||||
~FederatedComm() override { stub_.reset(); }
|
||||
|
||||
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
|
||||
LOG(FATAL) << "peer to peer communication is not allowed for federated learning.";
|
||||
return nullptr;
|
||||
}
|
||||
[[nodiscard]] Result LogTracker(std::string msg) const override {
|
||||
LOG(CONSOLE) << msg;
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@@ -4,12 +4,15 @@
|
||||
#include "federated_server.h"
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <grpcpp/server.h> // for Server
|
||||
#include <grpcpp/server_builder.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "../../src/collective/comm.h"
|
||||
#include "../../src/common/io.h"
|
||||
#include "../../src/common/json_utils.h"
|
||||
|
||||
namespace xgboost::federated {
|
||||
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
|
||||
@@ -46,7 +49,7 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
|
||||
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{world_size};
|
||||
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
auto options =
|
||||
@@ -68,7 +71,7 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
|
||||
void RunInsecureServer(int port, std::size_t world_size) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{world_size};
|
||||
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <federated.grpc.pb.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <future> // for future
|
||||
|
||||
#include "../../src/collective/in_memory_handler.h"
|
||||
#include "../../src/collective/tracker.h" // for Tracker
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
namespace xgboost {
|
||||
namespace federated {
|
||||
|
||||
namespace xgboost::federated {
|
||||
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
explicit FederatedService(std::size_t const world_size) : handler_{world_size} {}
|
||||
explicit FederatedService(std::int32_t world_size)
|
||||
: handler_{static_cast<std::size_t>(world_size)} {}
|
||||
|
||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
@@ -34,6 +38,4 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file);
|
||||
|
||||
void RunInsecureServer(int port, std::size_t world_size);
|
||||
|
||||
} // namespace federated
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::federated
|
||||
|
||||
101
plugin/federated/federated_tracker.cc
Normal file
101
plugin/federated/federated_tracker.cc
Normal file
@@ -0,0 +1,101 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include "federated_tracker.h"
|
||||
|
||||
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
|
||||
#include <grpcpp/server_builder.h> // for ServerBuilder
|
||||
|
||||
#include <chrono> // for ms
|
||||
#include <cstdint> // for int32_t
|
||||
#include <exception> // for exception
|
||||
#include <limits> // for numeric_limits
|
||||
#include <string> // for string
|
||||
#include <thread> // for sleep_for
|
||||
|
||||
#include "../../src/common/io.h" // for ReadAll
|
||||
#include "../../src/common/json_utils.h" // for RequiredArg
|
||||
#include "../../src/common/timer.h" // for Timer
|
||||
#include "federated_server.h" // for FederatedService
|
||||
|
||||
namespace xgboost::collective {
|
||||
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
|
||||
auto is_secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
|
||||
if (is_secure) {
|
||||
server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__);
|
||||
server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__);
|
||||
client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
std::future<Result> FederatedTracker::Run() {
|
||||
return std::async([this]() {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
|
||||
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
|
||||
grpc::ServerBuilder builder;
|
||||
|
||||
if (this->server_cert_file_.empty()) {
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||
if (this->port_ == 0) {
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
|
||||
} else {
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
server_ = builder.BuildAndStart();
|
||||
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
|
||||
<< this->n_workers_;
|
||||
} else {
|
||||
auto options = grpc::SslServerCredentialsOptions(
|
||||
GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
|
||||
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file_);
|
||||
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
|
||||
key.private_key = xgboost::common::ReadAll(server_key_path_);
|
||||
key.cert_chain = xgboost::common::ReadAll(server_cert_file_);
|
||||
options.pem_key_cert_pairs.push_back(key);
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||
if (this->port_ == 0) {
|
||||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options), &port_);
|
||||
} else {
|
||||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
server_ = builder.BuildAndStart();
|
||||
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
|
||||
<< n_workers_;
|
||||
}
|
||||
|
||||
try {
|
||||
server_->Wait();
|
||||
} catch (std::exception const& e) {
|
||||
return collective::Fail(std::string{e.what()});
|
||||
}
|
||||
return collective::Success();
|
||||
});
|
||||
}
|
||||
|
||||
FederatedTracker::~FederatedTracker() = default;
|
||||
|
||||
Result FederatedTracker::Shutdown() {
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
using namespace std::chrono_literals;
|
||||
while (!server_) {
|
||||
timer.Stop();
|
||||
auto ela = timer.ElapsedSeconds();
|
||||
if (ela > this->Timeout().count()) {
|
||||
return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) +
|
||||
" seconds.");
|
||||
}
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
|
||||
try {
|
||||
server_->Shutdown();
|
||||
} catch (std::exception const& e) {
|
||||
return Fail("Failed to shutdown:" + std::string{e.what()});
|
||||
}
|
||||
|
||||
return Success();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
41
plugin/federated/federated_tracker.h
Normal file
41
plugin/federated/federated_tracker.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <federated.grpc.pb.h> // for Server
|
||||
|
||||
#include <future> // for future
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../src/collective/tracker.h" // for Tracker
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedTracker : public collective::Tracker {
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
std::string server_key_path_;
|
||||
std::string server_cert_file_;
|
||||
std::string client_cert_file_;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief CTOR
|
||||
*
|
||||
* @param config Configuration, other than the base configuration from Tracker, we have:
|
||||
*
|
||||
* - federated_secure: bool whether this is a secure server.
|
||||
* - server_key_path: path to the key.
|
||||
* - server_cert_path: certificate path.
|
||||
* - client_cert_path: certificate path for client.
|
||||
*/
|
||||
explicit FederatedTracker(Json const& config);
|
||||
~FederatedTracker() override;
|
||||
std::future<Result> Run() override;
|
||||
// federated tracker do not provide initialization parameters, users have to provide it
|
||||
// themseleves.
|
||||
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
|
||||
[[nodiscard]] Result Shutdown();
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
Reference in New Issue
Block a user