diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index be854d755..7c2cfa6fb 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -28,6 +28,6 @@ target_sources(federated_client INTERFACE federated_client.h) target_link_libraries(federated_client INTERFACE federated_proto) # Rabit engine for Federated Learning. -target_sources(objxgboost PRIVATE federated_server.cc) +target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc) target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL") target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1) diff --git a/plugin/federated/federated_comm.cc b/plugin/federated/federated_comm.cc new file mode 100644 index 000000000..4b51fd52d --- /dev/null +++ b/plugin/federated/federated_comm.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2023, XGBoost contributors + */ +#include "federated_comm.h" + +#include + +#include // for int32_t +#include // for getenv +#include // for string, stoi + +#include "../../src/common/common.h" // for Split +#include "../../src/common/json_utils.h" // for OptionalArg +#include "xgboost/json.h" // for Json +#include "xgboost/logging.h" + +namespace xgboost::collective { +void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_t world, + std::int32_t rank, std::string const& server_cert, + std::string const& client_key, std::string const& client_cert) { + this->rank_ = rank; + this->world_ = world; + + this->tracker_.host = host; + this->tracker_.port = port; + this->tracker_.rank = rank; + + CHECK_GE(world, 1) << "Invalid world size."; + CHECK_GE(rank, 0) << "Invalid worker rank."; + CHECK_LT(rank, world) << "Invalid worker rank."; + + if (server_cert.empty()) { + stub_ = [&] { + grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(std::numeric_limits::max()); + return federated::Federated::NewStub( + grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args)); + }(); + } else { + stub_ = [&] { + grpc::SslCredentialsOptions options; + options.pem_root_certs = server_cert; + options.pem_private_key = client_key; + options.pem_cert_chain = client_cert; + grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(std::numeric_limits::max()); + auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args); + channel->WaitForConnected( + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN))); + return federated::Federated::NewStub(channel); + }(); + } +} + +FederatedComm::FederatedComm(Json const& config) { + /** + * Topology + */ + std::string server_address{}; + std::int32_t world_size{0}; + std::int32_t rank{-1}; + // Parse environment variables first. + auto* value = std::getenv("FEDERATED_SERVER_ADDRESS"); + if (value != nullptr) { + server_address = value; + } + value = std::getenv("FEDERATED_WORLD_SIZE"); + if (value != nullptr) { + world_size = std::stoi(value); + } + value = std::getenv("FEDERATED_RANK"); + if (value != nullptr) { + rank = std::stoi(value); + } + + server_address = OptionalArg(config, "federated_server_address", server_address); + world_size = + OptionalArg(config, "federated_world_size", static_cast(world_size)); + rank = OptionalArg(config, "federated_rank", static_cast(rank)); + + auto parsed = common::Split(server_address, ':'); + CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address; + + CHECK_NE(rank, -1) << "Parameter `federated_rank` is required"; + CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required."; + CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required."; + + /** + * Certificates + */ + std::string server_cert{}; + std::string client_key{}; + std::string client_cert{}; + value = getenv("FEDERATED_SERVER_CERT_PATH"); + if (value != nullptr) { + server_cert = value; + } + value = getenv("FEDERATED_CLIENT_KEY_PATH"); + if (value != nullptr) { + client_key = value; + } + value = getenv("FEDERATED_CLIENT_CERT_PATH"); + if (value != nullptr) { + client_cert = value; + } + + server_cert = OptionalArg(config, "federated_server_cert_path", server_cert); + client_key = OptionalArg(config, "federated_client_key_path", client_key); + client_cert = OptionalArg(config, "federated_client_cert_path", client_cert); + + this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key, + client_cert); +} +} // namespace xgboost::collective diff --git a/plugin/federated/federated_comm.h b/plugin/federated/federated_comm.h new file mode 100644 index 000000000..8e6fe7d67 --- /dev/null +++ b/plugin/federated/federated_comm.h @@ -0,0 +1,53 @@ +/** + * Copyright 2023, XGBoost contributors + */ +#pragma once + +#include +#include + +#include // for int32_t +#include // for unique_ptr +#include // for string + +#include "../../src/collective/comm.h" // for Comm +#include "../../src/common/json_utils.h" // for OptionalArg +#include "xgboost/json.h" + +namespace xgboost::collective { +class FederatedComm : public Comm { + std::unique_ptr stub_; + + void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank, + std::string const& server_cert, std::string const& client_key, + std::string const& client_cert); + + public: + /** + * @param config + * + * - federated_server_address: Tracker address + * - federated_world_size: The number of workers + * - federated_rank: Rank of federated worker + * - federated_server_cert_path + * - federated_client_key_path + * - federated_client_cert_path + */ + explicit FederatedComm(Json const& config); + explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world, + std::int32_t rank) { + this->Init(host, port, world, rank, {}, {}, {}); + } + ~FederatedComm() override { stub_.reset(); } + + [[nodiscard]] std::shared_ptr Chan(std::int32_t) const override { + LOG(FATAL) << "peer to peer communication is not allowed for federated learning."; + return nullptr; + } + [[nodiscard]] Result LogTracker(std::string msg) const override { + LOG(CONSOLE) << msg; + return Success(); + } + [[nodiscard]] bool IsFederated() const override { return true; } +}; +} // namespace xgboost::collective diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index ad6cf6022..9dd97c2e1 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -4,12 +4,15 @@ #include "federated_server.h" #include +#include // for Server #include #include #include +#include "../../src/collective/comm.h" #include "../../src/common/io.h" +#include "../../src/common/json_utils.h" namespace xgboost::federated { grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request, @@ -46,7 +49,7 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest void RunServer(int port, std::size_t world_size, char const* server_key_file, char const* server_cert_file, char const* client_cert_file) { std::string const server_address = "0.0.0.0:" + std::to_string(port); - FederatedService service{world_size}; + FederatedService service{static_cast(world_size)}; grpc::ServerBuilder builder; auto options = @@ -68,7 +71,7 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file, void RunInsecureServer(int port, std::size_t world_size) { std::string const server_address = "0.0.0.0:" + std::to_string(port); - FederatedService service{world_size}; + FederatedService service{static_cast(world_size)}; grpc::ServerBuilder builder; builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h index 711ef5588..20f3149f9 100644 --- a/plugin/federated/federated_server.h +++ b/plugin/federated/federated_server.h @@ -1,18 +1,22 @@ -/*! - * Copyright 2022 XGBoost contributors +/** + * Copyright 2022-2023, XGBoost contributors */ #pragma once #include +#include // for int32_t +#include // for future + #include "../../src/collective/in_memory_handler.h" +#include "../../src/collective/tracker.h" // for Tracker +#include "xgboost/collective/result.h" // for Result -namespace xgboost { -namespace federated { - +namespace xgboost::federated { class FederatedService final : public Federated::Service { public: - explicit FederatedService(std::size_t const world_size) : handler_{world_size} {} + explicit FederatedService(std::int32_t world_size) + : handler_{static_cast(world_size)} {} grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, AllgatherReply* reply) override; @@ -34,6 +38,4 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file, char const* server_cert_file, char const* client_cert_file); void RunInsecureServer(int port, std::size_t world_size); - -} // namespace federated -} // namespace xgboost +} // namespace xgboost::federated diff --git a/plugin/federated/federated_tracker.cc b/plugin/federated/federated_tracker.cc new file mode 100644 index 000000000..3dad9d7ce --- /dev/null +++ b/plugin/federated/federated_tracker.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2022-2023, XGBoost contributors + */ +#include "federated_tracker.h" + +#include // for InsecureServerCredentials, ... +#include // for ServerBuilder + +#include // for ms +#include // for int32_t +#include // for exception +#include // for numeric_limits +#include // for string +#include // for sleep_for + +#include "../../src/common/io.h" // for ReadAll +#include "../../src/common/json_utils.h" // for RequiredArg +#include "../../src/common/timer.h" // for Timer +#include "federated_server.h" // for FederatedService + +namespace xgboost::collective { +FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} { + auto is_secure = RequiredArg(config, "federated_secure", __func__); + if (is_secure) { + server_key_path_ = RequiredArg(config, "server_key_path", __func__); + server_cert_file_ = RequiredArg(config, "server_cert_path", __func__); + client_cert_file_ = RequiredArg(config, "client_cert_path", __func__); + } +} + +std::future FederatedTracker::Run() { + return std::async([this]() { + std::string const server_address = "0.0.0.0:" + std::to_string(this->port_); + federated::FederatedService service{static_cast(this->n_workers_)}; + grpc::ServerBuilder builder; + + if (this->server_cert_file_.empty()) { + builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); + if (this->port_ == 0) { + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); + } else { + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + } + builder.RegisterService(&service); + server_ = builder.BuildAndStart(); + LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size " + << this->n_workers_; + } else { + auto options = grpc::SslServerCredentialsOptions( + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + options.pem_root_certs = xgboost::common::ReadAll(client_cert_file_); + auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair(); + key.private_key = xgboost::common::ReadAll(server_key_path_); + key.cert_chain = xgboost::common::ReadAll(server_cert_file_); + options.pem_key_cert_pairs.push_back(key); + builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); + if (this->port_ == 0) { + builder.AddListeningPort(server_address, grpc::SslServerCredentials(options), &port_); + } else { + builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); + } + builder.RegisterService(&service); + server_ = builder.BuildAndStart(); + LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size " + << n_workers_; + } + + try { + server_->Wait(); + } catch (std::exception const& e) { + return collective::Fail(std::string{e.what()}); + } + return collective::Success(); + }); +} + +FederatedTracker::~FederatedTracker() = default; + +Result FederatedTracker::Shutdown() { + common::Timer timer; + timer.Start(); + using namespace std::chrono_literals; + while (!server_) { + timer.Stop(); + auto ela = timer.ElapsedSeconds(); + if (ela > this->Timeout().count()) { + return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) + + " seconds."); + } + std::this_thread::sleep_for(10ms); + } + + try { + server_->Shutdown(); + } catch (std::exception const& e) { + return Fail("Failed to shutdown:" + std::string{e.what()}); + } + + return Success(); +} +} // namespace xgboost::collective diff --git a/plugin/federated/federated_tracker.h b/plugin/federated/federated_tracker.h new file mode 100644 index 000000000..9043adb38 --- /dev/null +++ b/plugin/federated/federated_tracker.h @@ -0,0 +1,41 @@ +/** + * Copyright 2022-2023, XGBoost contributors + */ +#pragma once +#include // for Server + +#include // for future +#include // for unique_ptr +#include // for string + +#include "../../src/collective/tracker.h" // for Tracker +#include "xgboost/collective/result.h" // for Result +#include "xgboost/json.h" // for Json + +namespace xgboost::collective { +class FederatedTracker : public collective::Tracker { + std::unique_ptr server_; + std::string server_key_path_; + std::string server_cert_file_; + std::string client_cert_file_; + + public: + /** + * @brief CTOR + * + * @param config Configuration, other than the base configuration from Tracker, we have: + * + * - federated_secure: bool whether this is a secure server. + * - server_key_path: path to the key. + * - server_cert_path: certificate path. + * - client_cert_path: certificate path for client. + */ + explicit FederatedTracker(Json const& config); + ~FederatedTracker() override; + std::future Run() override; + // federated tracker do not provide initialization parameters, users have to provide it + // themseleves. + [[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; } + [[nodiscard]] Result Shutdown(); +}; +} // namespace xgboost::collective diff --git a/src/collective/tracker.h b/src/collective/tracker.h index 7bbee3c8d..f90373220 100644 --- a/src/collective/tracker.h +++ b/src/collective/tracker.h @@ -50,6 +50,7 @@ class Tracker { [[nodiscard]] virtual std::future Run() = 0; [[nodiscard]] virtual Json WorkerArgs() const = 0; [[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; } + [[nodiscard]] virtual std::int32_t Port() const { return port_; } }; class RabitTracker : public Tracker { @@ -124,7 +125,6 @@ class RabitTracker : public Tracker { std::future Run() override; - [[nodiscard]] std::int32_t Port() const { return port_; } [[nodiscard]] Json WorkerArgs() const override { Json args{Object{}}; args["DMLC_TRACKER_URI"] = String{host_}; diff --git a/tests/cpp/common/test_bitfield.cc b/tests/cpp/common/test_bitfield.cc index 902e69f85..564776642 100644 --- a/tests/cpp/common/test_bitfield.cc +++ b/tests/cpp/common/test_bitfield.cc @@ -97,4 +97,29 @@ TEST(BitField, Clear) { TestBitFieldClear(19); } } + +TEST(BitField, CTZ) { + { + auto cnt = TrailingZeroBits(0); + ASSERT_EQ(cnt, sizeof(std::uint32_t) * 8); + } + { + auto cnt = TrailingZeroBits(0b00011100); + ASSERT_EQ(cnt, 2); + cnt = detail::TrailingZeroBitsImpl(0b00011100); + ASSERT_EQ(cnt, 2); + } + { + auto cnt = TrailingZeroBits(0b00011101); + ASSERT_EQ(cnt, 0); + cnt = detail::TrailingZeroBitsImpl(0b00011101); + ASSERT_EQ(cnt, 0); + } + { + auto cnt = TrailingZeroBits(0b1000000000000000); + ASSERT_EQ(cnt, 15); + cnt = detail::TrailingZeroBitsImpl(0b1000000000000000); + ASSERT_EQ(cnt, 15); + } +} } // namespace xgboost diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 82a55450e..9adda8aed 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -572,4 +572,31 @@ class BaseMGPUTest : public ::testing::Test { class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{}; inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); } + +/** + * @brief poor man's gmock for message matching. + * + * @tparam Error The type of expected execption. + * + * @param submsg A substring of the actual error message. + * @param fn The function that throws Error + */ +template +void ExpectThrow(std::string submsg, Fn&& fn) { + try { + fn(); + } catch (Error const& exc) { + auto actual = std::string{exc.what()}; + ASSERT_NE(actual.find(submsg), std::string::npos) + << "Expecting substring `" << submsg << "` from the error message." + << " Got:\n" + << actual << "\n"; + return; + } catch (std::exception const& exc) { + auto actual = exc.what(); + ASSERT_TRUE(false) << "An unexpected type of exception is thrown. what:" << actual; + return; + } + ASSERT_TRUE(false) << "No exception is thrown"; +} } // namespace xgboost diff --git a/tests/cpp/plugin/federated/test_federated_comm.cc b/tests/cpp/plugin/federated/test_federated_comm.cc new file mode 100644 index 000000000..5bbde1bbb --- /dev/null +++ b/tests/cpp/plugin/federated/test_federated_comm.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2022-2023, XGBoost contributors + */ +#include + +#include // for string +#include // for thread + +#include "../../../../plugin/federated/federated_comm.h" +#include "../../collective/net_test.h" // for SocketTest +#include "../../helpers.h" // for ExpectThrow +#include "test_worker.h" // for TestFederated +#include "xgboost/json.h" // for Json + +namespace xgboost::collective { +namespace { +class FederatedCommTest : public SocketTest {}; +} // namespace + +TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) { + auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; }; + ExpectThrow("Invalid world size.", construct); +} + +TEST_F(FederatedCommTest, ThrowOnRankTooSmall) { + auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; }; + ExpectThrow("Invalid worker rank.", construct); +} + +TEST_F(FederatedCommTest, ThrowOnRankTooBig) { + auto construct = [] { FederatedComm comm{"localhost", 0, 1, 1}; }; + ExpectThrow("Invalid worker rank.", construct); +} + +TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) { + auto construct = [] { + Json config{Object{}}; + config["federated_server_address"] = std::string("localhost:0"); + config["federated_world_size"] = std::string("1"); + config["federated_rank"] = Integer(0); + FederatedComm comm(config); + }; + ExpectThrow("got: `String`", construct); +} + +TEST_F(FederatedCommTest, ThrowOnRankNotInteger) { + auto construct = [] { + Json config{Object{}}; + config["federated_server_address"] = std::string("localhost:0"); + config["federated_world_size"] = 1; + config["federated_rank"] = std::string("0"); + FederatedComm comm(config); + }; + ExpectThrow("got: `String`", construct); +} + +TEST_F(FederatedCommTest, GetWorldSizeAndRank) { + Json config{Object{}}; + config["federated_world_size"] = 6; + config["federated_rank"] = 3; + config["federated_server_address"] = String{"localhost:0"}; + FederatedComm comm{config}; + EXPECT_EQ(comm.World(), 6); + EXPECT_EQ(comm.Rank(), 3); +} + +TEST_F(FederatedCommTest, IsDistributed) { + FederatedComm comm{"localhost", 0, 2, 1}; + EXPECT_TRUE(comm.IsDistributed()); +} + +TEST_F(FederatedCommTest, InsecureTracker) { + std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u); + TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) { + Json config{Object{}}; + config["federated_world_size"] = n_workers; + config["federated_rank"] = rank; + config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); + FederatedComm comm{config}; + ASSERT_EQ(comm.Rank(), rank); + ASSERT_EQ(comm.World(), n_workers); + }); +} +} // namespace xgboost::collective diff --git a/tests/cpp/plugin/federated/test_worker.h b/tests/cpp/plugin/federated/test_worker.h new file mode 100644 index 000000000..719b4c343 --- /dev/null +++ b/tests/cpp/plugin/federated/test_worker.h @@ -0,0 +1,42 @@ +/** + * Copyright 2022-2023, XGBoost contributors + */ +#pragma once + +#include + +#include // for ms +#include // for thread + +#include "../../../../plugin/federated/federated_tracker.h" +#include "xgboost/json.h" // for Json + +namespace xgboost::collective { +template +void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { + Json config{Object()}; + config["federated_secure"] = Boolean{false}; + config["n_workers"] = Integer{n_workers}; + FederatedTracker tracker{config}; + auto fut = tracker.Run(); + + std::vector workers; + using namespace std::chrono_literals; + while (tracker.Port() == 0) { + std::this_thread::sleep_for(100ms); + } + std::int32_t port = tracker.Port(); + + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { fn(port, i); }); + } + + for (auto& t : workers) { + t.join(); + } + + auto rc = tracker.Shutdown(); + ASSERT_TRUE(rc.OK()) << rc.Report(); + ASSERT_TRUE(fut.get().OK()); +} +} // namespace xgboost::collective diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h index b756adefd..3dd0c3a1f 100644 --- a/tests/cpp/plugin/helpers.h +++ b/tests/cpp/plugin/helpers.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2022-2023 XGBoost contributors +/** + * Copyright 2022-2023, XGBoost contributors */ #pragma once @@ -26,7 +26,7 @@ class ServerForTest { explicit ServerForTest(std::size_t world_size) { server_thread_.reset(new std::thread([this, world_size] { grpc::ServerBuilder builder; - xgboost::federated::FederatedService service{world_size}; + xgboost::federated::FederatedService service{static_cast(world_size)}; int selected_port; builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port); builder.RegisterService(&service);