[coll] Federated comm. (#9732)
This commit is contained in:
parent
fa65cf6646
commit
80390e6cb6
@ -28,6 +28,6 @@ target_sources(federated_client INTERFACE federated_client.h)
|
||||
target_link_libraries(federated_client INTERFACE federated_proto)
|
||||
|
||||
# Rabit engine for Federated Learning.
|
||||
target_sources(objxgboost PRIVATE federated_server.cc)
|
||||
target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
|
||||
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
|
||||
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
|
||||
|
||||
114
plugin/federated/federated_comm.cc
Normal file
114
plugin/federated/federated_comm.cc
Normal file
@ -0,0 +1,114 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
*/
|
||||
#include "federated_comm.h"
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <cstdlib> // for getenv
|
||||
#include <string> // for string, stoi
|
||||
|
||||
#include "../../src/common/common.h" // for Split
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h" // for Json
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_t world,
|
||||
std::int32_t rank, std::string const& server_cert,
|
||||
std::string const& client_key, std::string const& client_cert) {
|
||||
this->rank_ = rank;
|
||||
this->world_ = world;
|
||||
|
||||
this->tracker_.host = host;
|
||||
this->tracker_.port = port;
|
||||
this->tracker_.rank = rank;
|
||||
|
||||
CHECK_GE(world, 1) << "Invalid world size.";
|
||||
CHECK_GE(rank, 0) << "Invalid worker rank.";
|
||||
CHECK_LT(rank, world) << "Invalid worker rank.";
|
||||
|
||||
if (server_cert.empty()) {
|
||||
stub_ = [&] {
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
return federated::Federated::NewStub(
|
||||
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
|
||||
}();
|
||||
} else {
|
||||
stub_ = [&] {
|
||||
grpc::SslCredentialsOptions options;
|
||||
options.pem_root_certs = server_cert;
|
||||
options.pem_private_key = client_key;
|
||||
options.pem_cert_chain = client_cert;
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
|
||||
channel->WaitForConnected(
|
||||
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
|
||||
return federated::Federated::NewStub(channel);
|
||||
}();
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * @brief Construct the communicator from a JSON configuration.
 *
 * Each setting is first read from its `FEDERATED_*` environment variable. A value present
 * in `config` under the matching `federated_*` key then takes precedence.
 *
 * @param config JSON object, see the class declaration for the accepted keys.
 */
FederatedComm::FederatedComm(Json const& config) {
  // Read an environment variable as a string, empty when it's not set.
  auto env_or_empty = [](char const* name) {
    auto const* value = std::getenv(name);
    return value == nullptr ? std::string{} : std::string{value};
  };
  // Read an environment variable as an integer, `fallback` when it's not set.
  auto env_or_int = [](char const* name, std::int32_t fallback) {
    auto const* value = std::getenv(name);
    return value == nullptr ? fallback : static_cast<std::int32_t>(std::stoi(value));
  };

  /**
   * Topology
   */
  // Parse environment variables first, 0 and -1 mark "not provided".
  std::string server_address = env_or_empty("FEDERATED_SERVER_ADDRESS");
  std::int32_t world_size = env_or_int("FEDERATED_WORLD_SIZE", 0);
  std::int32_t rank = env_or_int("FEDERATED_RANK", -1);

  server_address = OptionalArg<String>(config, "federated_server_address", server_address);
  world_size =
      OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
  rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));

  // Report missing parameters before validating their format. Otherwise a missing address
  // is reported as an invalid address and the "is required" message can never be seen.
  CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
  CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
  CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";

  // The address is expected in the form of `host:port`.
  auto parsed = common::Split(server_address, ':');
  CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address;

  /**
   * Certificates
   */
  // NOTE(review): the parameter names say these are file *paths*, yet Init() assigns them to
  // the gRPC `pem_*` options directly. Confirm whether the files need to be read first.
  std::string server_cert = env_or_empty("FEDERATED_SERVER_CERT_PATH");
  std::string client_key = env_or_empty("FEDERATED_CLIENT_KEY_PATH");
  std::string client_cert = env_or_empty("FEDERATED_CLIENT_CERT_PATH");

  server_cert = OptionalArg<String>(config, "federated_server_cert_path", server_cert);
  client_key = OptionalArg<String>(config, "federated_client_key_path", client_key);
  client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert);

  this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
             client_cert);
}
|
||||
} // namespace xgboost::collective
|
||||
53
plugin/federated/federated_comm.h
Normal file
53
plugin/federated/federated_comm.h
Normal file
@ -0,0 +1,53 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <federated.grpc.pb.h>
|
||||
#include <federated.pb.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../src/collective/comm.h" // for Comm
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedComm : public Comm {
|
||||
std::unique_ptr<federated::Federated::Stub> stub_;
|
||||
|
||||
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
|
||||
std::string const& server_cert, std::string const& client_key,
|
||||
std::string const& client_cert);
|
||||
|
||||
public:
|
||||
/**
|
||||
* @param config
|
||||
*
|
||||
* - federated_server_address: Tracker address
|
||||
* - federated_world_size: The number of workers
|
||||
* - federated_rank: Rank of federated worker
|
||||
* - federated_server_cert_path
|
||||
* - federated_client_key_path
|
||||
* - federated_client_cert_path
|
||||
*/
|
||||
explicit FederatedComm(Json const& config);
|
||||
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
|
||||
std::int32_t rank) {
|
||||
this->Init(host, port, world, rank, {}, {}, {});
|
||||
}
|
||||
~FederatedComm() override { stub_.reset(); }
|
||||
|
||||
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
|
||||
LOG(FATAL) << "peer to peer communication is not allowed for federated learning.";
|
||||
return nullptr;
|
||||
}
|
||||
[[nodiscard]] Result LogTracker(std::string msg) const override {
|
||||
LOG(CONSOLE) << msg;
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@ -4,12 +4,15 @@
|
||||
#include "federated_server.h"
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <grpcpp/server.h> // for Server
|
||||
#include <grpcpp/server_builder.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "../../src/collective/comm.h"
|
||||
#include "../../src/common/io.h"
|
||||
#include "../../src/common/json_utils.h"
|
||||
|
||||
namespace xgboost::federated {
|
||||
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
|
||||
@ -46,7 +49,7 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
|
||||
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{world_size};
|
||||
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
auto options =
|
||||
@ -68,7 +71,7 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
|
||||
void RunInsecureServer(int port, std::size_t world_size) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{world_size};
|
||||
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
|
||||
@ -1,18 +1,22 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <federated.grpc.pb.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <future> // for future
|
||||
|
||||
#include "../../src/collective/in_memory_handler.h"
|
||||
#include "../../src/collective/tracker.h" // for Tracker
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
namespace xgboost {
|
||||
namespace federated {
|
||||
|
||||
namespace xgboost::federated {
|
||||
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
explicit FederatedService(std::size_t const world_size) : handler_{world_size} {}
|
||||
explicit FederatedService(std::int32_t world_size)
|
||||
: handler_{static_cast<std::size_t>(world_size)} {}
|
||||
|
||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
@ -34,6 +38,4 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file);
|
||||
|
||||
void RunInsecureServer(int port, std::size_t world_size);
|
||||
|
||||
} // namespace federated
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::federated
|
||||
|
||||
101
plugin/federated/federated_tracker.cc
Normal file
101
plugin/federated/federated_tracker.cc
Normal file
@ -0,0 +1,101 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include "federated_tracker.h"
|
||||
|
||||
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
|
||||
#include <grpcpp/server_builder.h> // for ServerBuilder
|
||||
|
||||
#include <chrono> // for ms
|
||||
#include <cstdint> // for int32_t
|
||||
#include <exception> // for exception
|
||||
#include <limits> // for numeric_limits
|
||||
#include <string> // for string
|
||||
#include <thread> // for sleep_for
|
||||
|
||||
#include "../../src/common/io.h" // for ReadAll
|
||||
#include "../../src/common/json_utils.h" // for RequiredArg
|
||||
#include "../../src/common/timer.h" // for Timer
|
||||
#include "federated_server.h" // for FederatedService
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
 * @brief Construct the tracker from a JSON configuration.
 *
 * `federated_secure` is always required. The key and certificate paths are required, and
 * only read, when `federated_secure` is true.
 */
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
  auto const secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
  if (!secure) {
    // Insecure server, the path members stay empty.
    return;
  }

  server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__);
  server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__);
  client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__);
}
|
||||
|
||||
std::future<Result> FederatedTracker::Run() {
|
||||
return std::async([this]() {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
|
||||
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
|
||||
grpc::ServerBuilder builder;
|
||||
|
||||
if (this->server_cert_file_.empty()) {
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||
if (this->port_ == 0) {
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
|
||||
} else {
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
server_ = builder.BuildAndStart();
|
||||
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
|
||||
<< this->n_workers_;
|
||||
} else {
|
||||
auto options = grpc::SslServerCredentialsOptions(
|
||||
GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
|
||||
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file_);
|
||||
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
|
||||
key.private_key = xgboost::common::ReadAll(server_key_path_);
|
||||
key.cert_chain = xgboost::common::ReadAll(server_cert_file_);
|
||||
options.pem_key_cert_pairs.push_back(key);
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||
if (this->port_ == 0) {
|
||||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options), &port_);
|
||||
} else {
|
||||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
server_ = builder.BuildAndStart();
|
||||
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
|
||||
<< n_workers_;
|
||||
}
|
||||
|
||||
try {
|
||||
server_->Wait();
|
||||
} catch (std::exception const& e) {
|
||||
return collective::Fail(std::string{e.what()});
|
||||
}
|
||||
return collective::Success();
|
||||
});
|
||||
}
|
||||
|
||||
// Defaulted out of line; the members, including the owned gRPC server, clean up themselves.
FederatedTracker::~FederatedTracker() = default;
|
||||
|
||||
/**
 * @brief Shut down the server started by Run.
 *
 * Waits for the server object to be created, polling every 10ms for at most `Timeout()`
 * seconds, then asks it to shut down.
 *
 * @return Success, or a failure if the server didn't appear in time or the shutdown threw.
 */
Result FederatedTracker::Shutdown() {
  using namespace std::chrono_literals;

  // Use a fixed deadline on a monotonic clock. Calling Stop() on a timer repeatedly after a
  // single Start() presumably adds the whole elapsed span again on every iteration, which
  // would make the timeout fire too early.
  auto const timeout = this->Timeout();
  auto const deadline = std::chrono::steady_clock::now() + timeout;

  // NOTE(review): `server_` is assigned inside the task launched by Run(); no
  // synchronization with that task is visible here -- confirm this polling is safe.
  while (!server_) {
    if (std::chrono::steady_clock::now() > deadline) {
      return Fail("Failed to shutdown, timeout:" + std::to_string(timeout.count()) +
                  " seconds.");
    }
    std::this_thread::sleep_for(10ms);
  }

  try {
    server_->Shutdown();
  } catch (std::exception const& e) {
    return Fail("Failed to shutdown:" + std::string{e.what()});
  }

  return Success();
}
|
||||
} // namespace xgboost::collective
|
||||
41
plugin/federated/federated_tracker.h
Normal file
41
plugin/federated/federated_tracker.h
Normal file
@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <federated.grpc.pb.h> // for Server
|
||||
|
||||
#include <future> // for future
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../src/collective/tracker.h" // for Tracker
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedTracker : public collective::Tracker {
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
std::string server_key_path_;
|
||||
std::string server_cert_file_;
|
||||
std::string client_cert_file_;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief CTOR
|
||||
*
|
||||
* @param config Configuration, other than the base configuration from Tracker, we have:
|
||||
*
|
||||
* - federated_secure: bool whether this is a secure server.
|
||||
* - server_key_path: path to the key.
|
||||
* - server_cert_path: certificate path.
|
||||
* - client_cert_path: certificate path for client.
|
||||
*/
|
||||
explicit FederatedTracker(Json const& config);
|
||||
~FederatedTracker() override;
|
||||
std::future<Result> Run() override;
|
||||
// federated tracker do not provide initialization parameters, users have to provide it
|
||||
// themseleves.
|
||||
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
|
||||
[[nodiscard]] Result Shutdown();
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@ -50,6 +50,7 @@ class Tracker {
|
||||
[[nodiscard]] virtual std::future<Result> Run() = 0;
|
||||
[[nodiscard]] virtual Json WorkerArgs() const = 0;
|
||||
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
|
||||
[[nodiscard]] virtual std::int32_t Port() const { return port_; }
|
||||
};
|
||||
|
||||
class RabitTracker : public Tracker {
|
||||
@ -124,7 +125,6 @@ class RabitTracker : public Tracker {
|
||||
|
||||
std::future<Result> Run() override;
|
||||
|
||||
[[nodiscard]] std::int32_t Port() const { return port_; }
|
||||
[[nodiscard]] Json WorkerArgs() const override {
|
||||
Json args{Object{}};
|
||||
args["DMLC_TRACKER_URI"] = String{host_};
|
||||
|
||||
@ -97,4 +97,29 @@ TEST(BitField, Clear) {
|
||||
TestBitFieldClear<RBitField8>(19);
|
||||
}
|
||||
}
|
||||
|
||||
// Count of trailing zero bits, checked for both the public function and its
// implementation in the detail namespace.
TEST(BitField, CTZ) {
  // An input of zero yields the full width of a 32-bit integer.
  auto n_zeros = TrailingZeroBits(0);
  ASSERT_EQ(n_zeros, sizeof(std::uint32_t) * 8);

  // Two zero bits below the lowest set bit.
  n_zeros = TrailingZeroBits(0b00011100);
  ASSERT_EQ(n_zeros, 2);
  n_zeros = detail::TrailingZeroBitsImpl(0b00011100);
  ASSERT_EQ(n_zeros, 2);

  // The lowest bit is set.
  n_zeros = TrailingZeroBits(0b00011101);
  ASSERT_EQ(n_zeros, 0);
  n_zeros = detail::TrailingZeroBitsImpl(0b00011101);
  ASSERT_EQ(n_zeros, 0);

  // Only bit 15 is set.
  n_zeros = TrailingZeroBits(0b1000000000000000);
  ASSERT_EQ(n_zeros, 15);
  n_zeros = detail::TrailingZeroBitsImpl(0b1000000000000000);
  ASSERT_EQ(n_zeros, 15);
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -572,4 +572,31 @@ class BaseMGPUTest : public ::testing::Test {
|
||||
class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{};
|
||||
|
||||
inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); }
|
||||
|
||||
/**
|
||||
* @brief poor man's gmock for message matching.
|
||||
*
|
||||
* @tparam Error The type of expected execption.
|
||||
*
|
||||
* @param submsg A substring of the actual error message.
|
||||
* @param fn The function that throws Error
|
||||
*/
|
||||
template <typename Error, typename Fn>
|
||||
void ExpectThrow(std::string submsg, Fn&& fn) {
|
||||
try {
|
||||
fn();
|
||||
} catch (Error const& exc) {
|
||||
auto actual = std::string{exc.what()};
|
||||
ASSERT_NE(actual.find(submsg), std::string::npos)
|
||||
<< "Expecting substring `" << submsg << "` from the error message."
|
||||
<< " Got:\n"
|
||||
<< actual << "\n";
|
||||
return;
|
||||
} catch (std::exception const& exc) {
|
||||
auto actual = exc.what();
|
||||
ASSERT_TRUE(false) << "An unexpected type of exception is thrown. what:" << actual;
|
||||
return;
|
||||
}
|
||||
ASSERT_TRUE(false) << "No exception is thrown";
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
84
tests/cpp/plugin/federated/test_federated_comm.cc
Normal file
84
tests/cpp/plugin/federated/test_federated_comm.cc
Normal file
@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../../plugin/federated/federated_comm.h"
|
||||
#include "../../collective/net_test.h" // for SocketTest
|
||||
#include "../../helpers.h" // for ExpectThrow
|
||||
#include "test_worker.h" // for TestFederated
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
// Fixture for the FederatedComm tests, it adds nothing on top of SocketTest.
class FederatedCommTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
// A world size of 0 is rejected.
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
  ExpectThrow<dmlc::Error>("Invalid world size.", [] {
    FederatedComm comm{"localhost", 0, 0, 0};
  });
}
|
||||
|
||||
// A negative rank is rejected.
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
  ExpectThrow<dmlc::Error>("Invalid worker rank.", [] {
    FederatedComm comm{"localhost", 0, 1, -1};
  });
}
|
||||
|
||||
// A rank equal to the world size is rejected.
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
  ExpectThrow<dmlc::Error>("Invalid worker rank.", [] {
    FederatedComm comm{"localhost", 0, 1, 1};
  });
}
|
||||
|
||||
// The world size in the configuration must be an integer, a string is rejected.
TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
  ExpectThrow<dmlc::Error>("got: `String`", [] {
    Json config{Object{}};
    config["federated_server_address"] = std::string("localhost:0");
    config["federated_world_size"] = std::string("1");  // wrong type on purpose
    config["federated_rank"] = Integer(0);
    FederatedComm comm(config);
  });
}
|
||||
|
||||
// The rank in the configuration must be an integer, a string is rejected.
TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
  ExpectThrow<dmlc::Error>("got: `String`", [] {
    Json config{Object{}};
    config["federated_server_address"] = std::string("localhost:0");
    config["federated_world_size"] = 1;
    config["federated_rank"] = std::string("0");  // wrong type on purpose
    FederatedComm comm(config);
  });
}
|
||||
|
||||
// The world size and the rank from the configuration are reported back by the communicator.
TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
  Json config{Object{}};
  config["federated_server_address"] = String{"localhost:0"};
  config["federated_rank"] = 3;
  config["federated_world_size"] = 6;

  FederatedComm comm{config};
  EXPECT_EQ(comm.World(), 6);
  EXPECT_EQ(comm.Rank(), 3);
}
|
||||
|
||||
// A communicator with 2 workers reports itself as distributed.
TEST_F(FederatedCommTest, IsDistributed) {
  FederatedComm comm{"localhost", 0, 2, 1};
  EXPECT_TRUE(comm.IsDistributed());
}
|
||||
|
||||
// Run an insecure tracker and construct one communicator per worker against it.
TEST_F(FederatedCommTest, InsecureTracker) {
  // Use at most 3 workers.
  std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);

  auto worker = [=](std::int32_t port, std::int32_t rank) {
    Json config{Object{}};
    config["federated_world_size"] = n_workers;
    config["federated_rank"] = rank;
    config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);

    FederatedComm comm{config};
    ASSERT_EQ(comm.Rank(), rank);
    ASSERT_EQ(comm.World(), n_workers);
  };
  TestFederated(n_workers, worker);
}
|
||||
} // namespace xgboost::collective
|
||||
42
tests/cpp/plugin/federated/test_worker.h
Normal file
42
tests/cpp/plugin/federated/test_worker.h
Normal file
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for ms
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../../plugin/federated/federated_tracker.h"
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
template <typename WorkerFn>
|
||||
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
Json config{Object()};
|
||||
config["federated_secure"] = Boolean{false};
|
||||
config["n_workers"] = Integer{n_workers};
|
||||
FederatedTracker tracker{config};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
using namespace std::chrono_literals;
|
||||
while (tracker.Port() == 0) {
|
||||
std::this_thread::sleep_for(100ms);
|
||||
}
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] { fn(port, i); });
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
auto rc = tracker.Shutdown();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2022-2023 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@ -26,7 +26,7 @@ class ServerForTest {
|
||||
explicit ServerForTest(std::size_t world_size) {
|
||||
server_thread_.reset(new std::thread([this, world_size] {
|
||||
grpc::ServerBuilder builder;
|
||||
xgboost::federated::FederatedService service{world_size};
|
||||
xgboost::federated::FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
int selected_port;
|
||||
builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port);
|
||||
builder.RegisterService(&service);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user