[coll] Federated comm. (#9732)

This commit is contained in:
Jiaming Yuan 2023-10-31 02:39:55 +08:00 committed by GitHub
parent fa65cf6646
commit 80390e6cb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 508 additions and 16 deletions

View File

@ -28,6 +28,6 @@ target_sources(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)
# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_server.cc)
target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)

View File

@ -0,0 +1,114 @@
/**
* Copyright 2023, XGBoost contributors
*/
#include "federated_comm.h"
#include <grpcpp/grpcpp.h>
#include <cstdint> // for int32_t
#include <cstdlib> // for getenv
#include <limits> // for numeric_limits
#include <string> // for string, stoi
#include "../../src/common/common.h" // for Split
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h" // for Json
#include "xgboost/logging.h"
namespace xgboost::collective {
void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank, std::string const& server_cert,
std::string const& client_key, std::string const& client_cert) {
this->rank_ = rank;
this->world_ = world;
this->tracker_.host = host;
this->tracker_.port = port;
this->tracker_.rank = rank;
CHECK_GE(world, 1) << "Invalid world size.";
CHECK_GE(rank, 0) << "Invalid worker rank.";
CHECK_LT(rank, world) << "Invalid worker rank.";
if (server_cert.empty()) {
stub_ = [&] {
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
return federated::Federated::NewStub(
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
}();
} else {
stub_ = [&] {
grpc::SslCredentialsOptions options;
options.pem_root_certs = server_cert;
options.pem_private_key = client_key;
options.pem_cert_chain = client_cert;
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
channel->WaitForConnected(
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
return federated::Federated::NewStub(channel);
}();
}
}
/**
 * @brief Construct the communicator from a JSON configuration.
 *
 * Every parameter is first looked up in the environment (`FEDERATED_SERVER_ADDRESS`,
 * `FEDERATED_WORLD_SIZE`, `FEDERATED_RANK`, `FEDERATED_SERVER_CERT_PATH`,
 * `FEDERATED_CLIENT_KEY_PATH`, `FEDERATED_CLIENT_CERT_PATH`); a value present in `config`
 * under the matching lower-case key overrides the environment.
 *
 * @param config JSON object with the `federated_*` parameters.
 */
FederatedComm::FederatedComm(Json const& config) {
  /**
   * Topology
   */
  std::string server_address{};
  std::int32_t world_size{0};
  std::int32_t rank{-1};
  // Parse environment variables first.
  if (auto const* value = std::getenv("FEDERATED_SERVER_ADDRESS")) {
    server_address = value;
  }
  if (auto const* value = std::getenv("FEDERATED_WORLD_SIZE")) {
    world_size = std::stoi(value);
  }
  if (auto const* value = std::getenv("FEDERATED_RANK")) {
    rank = std::stoi(value);
  }
  // Values from the configuration take precedence over the environment.
  server_address = OptionalArg<String>(config, "federated_server_address", server_address);
  world_size =
      OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
  rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));

  // Report missing parameters before checking the address format, otherwise an absent
  // address is reported as a malformed one.
  CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
  CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
  CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";

  // The address must have the form `host:port`.
  auto parsed = common::Split(server_address, ':');
  CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address;

  /**
   * Certificates
   */
  std::string server_cert{};
  std::string client_key{};
  std::string client_cert{};
  if (auto const* value = std::getenv("FEDERATED_SERVER_CERT_PATH")) {
    server_cert = value;
  }
  if (auto const* value = std::getenv("FEDERATED_CLIENT_KEY_PATH")) {
    client_key = value;
  }
  if (auto const* value = std::getenv("FEDERATED_CLIENT_CERT_PATH")) {
    client_cert = value;
  }
  server_cert = OptionalArg<String>(config, "federated_server_cert_path", server_cert);
  client_key = OptionalArg<String>(config, "federated_client_key_path", client_key);
  client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert);

  this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
             client_cert);
}
} // namespace xgboost::collective

View File

@ -0,0 +1,53 @@
/**
* Copyright 2023, XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <federated.pb.h>
#include <cstdint> // for int32_t
#include <memory> // for unique_ptr
#include <string> // for string
#include "../../src/collective/comm.h" // for Comm
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h"
namespace xgboost::collective {
/**
 * @brief Communicator for federated learning. Workers talk to the federated server through
 *        a gRPC stub; channels between workers are not available.
 */
class FederatedComm : public Comm {
// gRPC stub for the federated service, created by Init.
std::unique_ptr<federated::Federated::Stub> stub_;
/**
 * @brief Check the world size and rank, then create the stub.
 *
 * @param host Address of the server.
 * @param port Port of the server.
 * @param world Number of workers.
 * @param rank Rank of this worker.
 * @param server_cert When empty an insecure channel is used, otherwise it is used for the
 *                    SSL credentials together with client_key and client_cert.
 * @param client_key Client key for the SSL credentials.
 * @param client_cert Client certificate for the SSL credentials.
 */
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
std::string const& server_cert, std::string const& client_key,
std::string const& client_cert);
public:
/**
 * @brief Construct from a JSON configuration. Each value can also be supplied through the
 *        upper-case environment variable of the same name (e.g. FEDERATED_RANK); the
 *        configuration overrides the environment.
 *
 * @param config
 *
 * - federated_server_address: Tracker address in the form of `host:port`
 * - federated_world_size: The number of workers
 * - federated_rank: Rank of federated worker
 * - federated_server_cert_path
 * - federated_client_key_path
 * - federated_client_cert_path
 */
explicit FederatedComm(Json const& config);
/**
 * @brief Construct a communicator without certificates, which selects the insecure channel.
 */
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
}
// Release the stub explicitly on destruction.
~FederatedComm() override { stub_.reset(); }
/**
 * @brief Always a fatal error, there is no channel between federated workers.
 */
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
LOG(FATAL) << "peer to peer communication is not allowed for federated learning.";
return nullptr;
}
/**
 * @brief Print the message to the local console and report success.
 */
[[nodiscard]] Result LogTracker(std::string msg) const override {
LOG(CONSOLE) << msg;
return Success();
}
// This communicator is always a federated one.
[[nodiscard]] bool IsFederated() const override { return true; }
};
} // namespace xgboost::collective

View File

@ -4,12 +4,15 @@
#include "federated_server.h"
#include <grpcpp/grpcpp.h>
#include <grpcpp/server.h> // for Server
#include <grpcpp/server_builder.h>
#include <xgboost/logging.h>
#include <sstream>
#include "../../src/collective/comm.h"
#include "../../src/common/io.h"
#include "../../src/common/json_utils.h"
namespace xgboost::federated {
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
@ -46,7 +49,7 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};
FederatedService service{static_cast<std::int32_t>(world_size)};
grpc::ServerBuilder builder;
auto options =
@ -68,7 +71,7 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
void RunInsecureServer(int port, std::size_t world_size) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};
FederatedService service{static_cast<std::int32_t>(world_size)};
grpc::ServerBuilder builder;
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

View File

@ -1,18 +1,22 @@
/*!
* Copyright 2022 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <cstdint> // for int32_t
#include <future> // for future
#include "../../src/collective/in_memory_handler.h"
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
namespace xgboost {
namespace federated {
namespace xgboost::federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::size_t const world_size) : handler_{world_size} {}
explicit FederatedService(std::int32_t world_size)
: handler_{static_cast<std::size_t>(world_size)} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
@ -34,6 +38,4 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file);
void RunInsecureServer(int port, std::size_t world_size);
} // namespace federated
} // namespace xgboost
} // namespace xgboost::federated

View File

@ -0,0 +1,101 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include "federated_tracker.h"
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
#include <grpcpp/server_builder.h> // for ServerBuilder
#include <chrono> // for ms
#include <cstdint> // for int32_t
#include <exception> // for exception
#include <limits> // for numeric_limits
#include <string> // for string
#include <thread> // for sleep_for
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for RequiredArg
#include "../../src/common/timer.h" // for Timer
#include "federated_server.h" // for FederatedService
namespace xgboost::collective {
/**
 * @brief Read the federated specific options on top of the base tracker configuration.
 *
 * `federated_secure` is mandatory. The key and certificate paths are required only when it
 * is true, and are left empty otherwise.
 */
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
  auto const secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
  if (!secure) {
    // Insecure server, no certificate is needed.
    return;
  }
  server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__);
  server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__);
  client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__);
}
std::future<Result> FederatedTracker::Run() {
return std::async([this]() {
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
grpc::ServerBuilder builder;
if (this->server_cert_file_.empty()) {
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
if (this->port_ == 0) {
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
} else {
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
<< this->n_workers_;
} else {
auto options = grpc::SslServerCredentialsOptions(
GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file_);
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
key.private_key = xgboost::common::ReadAll(server_key_path_);
key.cert_chain = xgboost::common::ReadAll(server_cert_file_);
options.pem_key_cert_pairs.push_back(key);
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
if (this->port_ == 0) {
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options), &port_);
} else {
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
<< n_workers_;
}
try {
server_->Wait();
} catch (std::exception const& e) {
return collective::Fail(std::string{e.what()});
}
return collective::Success();
});
}
// Defaulted destructor, defined in this translation unit; members are destroyed as usual.
FederatedTracker::~FederatedTracker() = default;
/**
 * @brief Shut down the server started by Run.
 *
 * Waits for the server object to be created, polling every 10ms, for at most `Timeout()`.
 *
 * @return Success when the server has been asked to shut down. A failure when the server
 *         does not appear within the timeout or when the shutdown throws.
 */
Result FederatedTracker::Shutdown() {
  using namespace std::chrono_literals;
  // A fixed deadline on a monotonic clock measures the wait exactly once, independent of how
  // many times the loop runs.
  auto const timeout = this->Timeout();
  auto const deadline = std::chrono::steady_clock::now() + timeout;
  while (!server_) {
    if (std::chrono::steady_clock::now() > deadline) {
      return Fail("Failed to shutdown, timeout:" + std::to_string(timeout.count()) +
                  " seconds.");
    }
    std::this_thread::sleep_for(10ms);
  }

  try {
    server_->Shutdown();
  } catch (std::exception const& e) {
    return Fail("Failed to shutdown:" + std::string{e.what()});
  }

  return Success();
}
} // namespace xgboost::collective

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h> // for Server
#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
/**
 * @brief Tracker that hosts the federated gRPC server.
 */
class FederatedTracker : public collective::Tracker {
// The running server, set by the task started in Run.
std::unique_ptr<grpc::Server> server_;
// Paths to the key and certificates, filled only for a secure server.
std::string server_key_path_;
std::string server_cert_file_;
std::string client_cert_file_;
public:
/**
 * @brief CTOR
 *
 * @param config Configuration. In addition to the base configuration from Tracker, we have:
 *
 * - federated_secure: bool, whether this is a secure server.
 * - server_key_path: path to the key.
 * - server_cert_path: certificate path.
 * - client_cert_path: certificate path for client.
 *
 * The three paths are required only when federated_secure is true.
 */
explicit FederatedTracker(Json const& config);
~FederatedTracker() override;
/**
 * @brief Start the server asynchronously.
 *
 * @return A future holding the result of running the server.
 */
std::future<Result> Run() override;
// The federated tracker does not provide initialization parameters, users have to provide
// them themselves.
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
/**
 * @brief Shut down the server, waiting up to the tracker timeout for it to be created.
 */
[[nodiscard]] Result Shutdown();
};
} // namespace xgboost::collective

View File

@ -50,6 +50,7 @@ class Tracker {
[[nodiscard]] virtual std::future<Result> Run() = 0;
[[nodiscard]] virtual Json WorkerArgs() const = 0;
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
[[nodiscard]] virtual std::int32_t Port() const { return port_; }
};
class RabitTracker : public Tracker {
@ -124,7 +125,6 @@ class RabitTracker : public Tracker {
std::future<Result> Run() override;
[[nodiscard]] std::int32_t Port() const { return port_; }
[[nodiscard]] Json WorkerArgs() const override {
Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host_};

View File

@ -97,4 +97,29 @@ TEST(BitField, Clear) {
TestBitFieldClear<RBitField8>(19);
}
}
// Count-trailing-zeros test. Each non-zero input is checked with both TrailingZeroBits and
// detail::TrailingZeroBitsImpl, which are expected to return the same count.
TEST(BitField, CTZ) {
{
// Zero input: the expected count is the number of bits in a 32-bit integer.
auto cnt = TrailingZeroBits(0);
ASSERT_EQ(cnt, sizeof(std::uint32_t) * 8);
}
{
// Lowest set bit at position 2.
auto cnt = TrailingZeroBits(0b00011100);
ASSERT_EQ(cnt, 2);
cnt = detail::TrailingZeroBitsImpl(0b00011100);
ASSERT_EQ(cnt, 2);
}
{
// Lowest bit is set, no trailing zero.
auto cnt = TrailingZeroBits(0b00011101);
ASSERT_EQ(cnt, 0);
cnt = detail::TrailingZeroBitsImpl(0b00011101);
ASSERT_EQ(cnt, 0);
}
{
// A single set bit at position 15.
auto cnt = TrailingZeroBits(0b1000000000000000);
ASSERT_EQ(cnt, 15);
cnt = detail::TrailingZeroBitsImpl(0b1000000000000000);
ASSERT_EQ(cnt, 15);
}
}
} // namespace xgboost

View File

@ -572,4 +572,31 @@ class BaseMGPUTest : public ::testing::Test {
class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{};
inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); }
/**
 * @brief A minimal stand-in for gmock's message matching.
 *
 * Runs `fn` and fails the current test unless it throws an `Error` whose message contains
 * `submsg`. An exception of another type derived from std::exception, or no exception at
 * all, is reported as a failure.
 *
 * @tparam Error The type of the expected exception.
 *
 * @param submsg A substring of the expected error message.
 * @param fn     The callable that is expected to throw Error.
 */
template <typename Error, typename Fn>
void ExpectThrow(std::string submsg, Fn&& fn) {
  try {
    fn();
  } catch (Error const& expected) {
    // The right type is thrown, now the message has to match.
    std::string const actual{expected.what()};
    ASSERT_NE(actual.find(submsg), std::string::npos)
        << "Expecting substring `" << submsg << "` from the error message."
        << " Got:\n"
        << actual << "\n";
    return;
  } catch (std::exception const& unexpected) {
    auto actual = unexpected.what();
    ASSERT_TRUE(false) << "An unexpected type of exception is thrown. what:" << actual;
  }
  // Reached only when `fn` returns normally.
  ASSERT_TRUE(false) << "No exception is thrown";
}
} // namespace xgboost

View File

@ -0,0 +1,84 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <string> // for string
#include <thread> // for thread
#include "../../../../plugin/federated/federated_comm.h"
#include "../../collective/net_test.h" // for SocketTest
#include "../../helpers.h" // for ExpectThrow
#include "test_worker.h" // for TestFederated
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
namespace {
// Fixture for the federated communicator tests, it adds nothing on top of SocketTest.
class FederatedCommTest : public SocketTest {};
} // namespace
// A world size of 0 must be rejected.
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
  ExpectThrow<dmlc::Error>("Invalid world size.",
                           [] { FederatedComm comm{"localhost", 0, 0, 0}; });
}
// A negative rank must be rejected.
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
  ExpectThrow<dmlc::Error>("Invalid worker rank.",
                           [] { FederatedComm comm{"localhost", 0, 1, -1}; });
}
// A rank equal to the world size is out of range and must be rejected.
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
  ExpectThrow<dmlc::Error>("Invalid worker rank.",
                           [] { FederatedComm comm{"localhost", 0, 1, 1}; });
}
// The world size must be an integer in the configuration, a string is an error.
TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
  auto make_comm = [] {
    Json cfg{Object{}};
    cfg["federated_server_address"] = std::string("localhost:0");
    cfg["federated_rank"] = Integer(0);
    // Wrong type on purpose.
    cfg["federated_world_size"] = std::string("1");
    FederatedComm comm(cfg);
  };
  ExpectThrow<dmlc::Error>("got: `String`", make_comm);
}
// The rank must be an integer in the configuration, a string is an error.
TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
  auto make_comm = [] {
    Json cfg{Object{}};
    cfg["federated_server_address"] = std::string("localhost:0");
    cfg["federated_world_size"] = 1;
    // Wrong type on purpose.
    cfg["federated_rank"] = std::string("0");
    FederatedComm comm(cfg);
  };
  ExpectThrow<dmlc::Error>("got: `String`", make_comm);
}
// The world size and the rank given in the configuration are reported back unchanged.
TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
  Json cfg{Object{}};
  cfg["federated_server_address"] = String{"localhost:0"};
  cfg["federated_world_size"] = 6;
  cfg["federated_rank"] = 3;

  FederatedComm comm{cfg};
  EXPECT_EQ(comm.World(), 6);
  EXPECT_EQ(comm.Rank(), 3);
}
// A communicator with two workers reports itself as distributed.
TEST_F(FederatedCommTest, IsDistributed) {
  std::int32_t const port{0};
  std::int32_t const world{2};
  std::int32_t const rank{1};
  FederatedComm comm{"localhost", port, world, rank};
  EXPECT_TRUE(comm.IsDistributed());
}
// Start an insecure tracker and construct one communicator per worker against it. Each
// worker checks that it reports the rank and the world size it has been configured with.
TEST_F(FederatedCommTest, InsecureTracker) {
  std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
  // hardware_concurrency is allowed to return 0 when the value is unknown, run with at
  // least one worker in that case.
  if (n_workers < 1) {
    n_workers = 1;
  }
  TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) {
    Json config{Object{}};
    config["federated_world_size"] = n_workers;
    config["federated_rank"] = rank;
    config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
    FederatedComm comm{config};
    ASSERT_EQ(comm.Rank(), rank);
    ASSERT_EQ(comm.World(), n_workers);
  });
}
} // namespace xgboost::collective

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <gtest/gtest.h>
#include <chrono> // for ms
#include <thread> // for thread
#include "../../../../plugin/federated/federated_tracker.h"
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
/**
 * @brief Run `fn` on `n_workers` threads against an insecure federated tracker.
 *
 * The tracker is started first and the workers are launched once it reports a non-zero
 * port. After all workers finish, the tracker is shut down and both the shutdown result and
 * the result of the tracker run are asserted to be OK.
 *
 * @param n_workers Number of worker threads, also the world size of the tracker.
 * @param fn        Callable invoked as `fn(port, rank)` on each worker thread.
 */
template <typename WorkerFn>
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
  Json config{Object()};
  config["federated_secure"] = Boolean{false};
  config["n_workers"] = Integer{n_workers};
  FederatedTracker tracker{config};
  auto fut = tracker.Run();

  // The port stays 0 until the server has picked one.
  using namespace std::chrono_literals;
  std::int32_t port = tracker.Port();
  while (port == 0) {
    std::this_thread::sleep_for(100ms);
    port = tracker.Port();
  }

  std::vector<std::thread> workers;
  for (std::int32_t i = 0; i < n_workers; ++i) {
    workers.emplace_back([=] { fn(port, i); });
  }
  for (auto& worker : workers) {
    worker.join();
  }

  auto rc = tracker.Shutdown();
  ASSERT_TRUE(rc.OK()) << rc.Report();
  ASSERT_TRUE(fut.get().OK());
}
} // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2022-2023 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
@ -26,7 +26,7 @@ class ServerForTest {
explicit ServerForTest(std::size_t world_size) {
server_thread_.reset(new std::thread([this, world_size] {
grpc::ServerBuilder builder;
xgboost::federated::FederatedService service{world_size};
xgboost::federated::FederatedService service{static_cast<std::int32_t>(world_size)};
int selected_port;
builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port);
builder.RegisterService(&service);