[coll] Federated comm. (#9732)
This commit is contained in:
@@ -97,4 +97,29 @@ TEST(BitField, Clear) {
|
||||
TestBitFieldClear<RBitField8>(19);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BitField, CTZ) {
|
||||
{
|
||||
auto cnt = TrailingZeroBits(0);
|
||||
ASSERT_EQ(cnt, sizeof(std::uint32_t) * 8);
|
||||
}
|
||||
{
|
||||
auto cnt = TrailingZeroBits(0b00011100);
|
||||
ASSERT_EQ(cnt, 2);
|
||||
cnt = detail::TrailingZeroBitsImpl(0b00011100);
|
||||
ASSERT_EQ(cnt, 2);
|
||||
}
|
||||
{
|
||||
auto cnt = TrailingZeroBits(0b00011101);
|
||||
ASSERT_EQ(cnt, 0);
|
||||
cnt = detail::TrailingZeroBitsImpl(0b00011101);
|
||||
ASSERT_EQ(cnt, 0);
|
||||
}
|
||||
{
|
||||
auto cnt = TrailingZeroBits(0b1000000000000000);
|
||||
ASSERT_EQ(cnt, 15);
|
||||
cnt = detail::TrailingZeroBitsImpl(0b1000000000000000);
|
||||
ASSERT_EQ(cnt, 15);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -572,4 +572,31 @@ class BaseMGPUTest : public ::testing::Test {
|
||||
class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{};
|
||||
|
||||
inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); }
|
||||
|
||||
/**
|
||||
* @brief poor man's gmock for message matching.
|
||||
*
|
||||
* @tparam Error The type of expected execption.
|
||||
*
|
||||
* @param submsg A substring of the actual error message.
|
||||
* @param fn The function that throws Error
|
||||
*/
|
||||
template <typename Error, typename Fn>
|
||||
void ExpectThrow(std::string submsg, Fn&& fn) {
|
||||
try {
|
||||
fn();
|
||||
} catch (Error const& exc) {
|
||||
auto actual = std::string{exc.what()};
|
||||
ASSERT_NE(actual.find(submsg), std::string::npos)
|
||||
<< "Expecting substring `" << submsg << "` from the error message."
|
||||
<< " Got:\n"
|
||||
<< actual << "\n";
|
||||
return;
|
||||
} catch (std::exception const& exc) {
|
||||
auto actual = exc.what();
|
||||
ASSERT_TRUE(false) << "An unexpected type of exception is thrown. what:" << actual;
|
||||
return;
|
||||
}
|
||||
ASSERT_TRUE(false) << "No exception is thrown";
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
84
tests/cpp/plugin/federated/test_federated_comm.cc
Normal file
84
tests/cpp/plugin/federated/test_federated_comm.cc
Normal file
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../../plugin/federated/federated_comm.h"
|
||||
#include "../../collective/net_test.h" // for SocketTest
|
||||
#include "../../helpers.h" // for ExpectThrow
|
||||
#include "test_worker.h" // for TestFederated
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class FederatedCommTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
|
||||
ExpectThrow<dmlc::Error>("Invalid world size.", construct);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
|
||||
ExpectThrow<dmlc::Error>("Invalid worker rank.", construct);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, 1}; };
|
||||
ExpectThrow<dmlc::Error>("Invalid worker rank.", construct);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
|
||||
auto construct = [] {
|
||||
Json config{Object{}};
|
||||
config["federated_server_address"] = std::string("localhost:0");
|
||||
config["federated_world_size"] = std::string("1");
|
||||
config["federated_rank"] = Integer(0);
|
||||
FederatedComm comm(config);
|
||||
};
|
||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
|
||||
auto construct = [] {
|
||||
Json config{Object{}};
|
||||
config["federated_server_address"] = std::string("localhost:0");
|
||||
config["federated_world_size"] = 1;
|
||||
config["federated_rank"] = std::string("0");
|
||||
FederatedComm comm(config);
|
||||
};
|
||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
|
||||
Json config{Object{}};
|
||||
config["federated_world_size"] = 6;
|
||||
config["federated_rank"] = 3;
|
||||
config["federated_server_address"] = String{"localhost:0"};
|
||||
FederatedComm comm{config};
|
||||
EXPECT_EQ(comm.World(), 6);
|
||||
EXPECT_EQ(comm.Rank(), 3);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, IsDistributed) {
|
||||
FederatedComm comm{"localhost", 0, 2, 1};
|
||||
EXPECT_TRUE(comm.IsDistributed());
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, InsecureTracker) {
|
||||
std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
|
||||
TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) {
|
||||
Json config{Object{}};
|
||||
config["federated_world_size"] = n_workers;
|
||||
config["federated_rank"] = rank;
|
||||
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedComm comm{config};
|
||||
ASSERT_EQ(comm.Rank(), rank);
|
||||
ASSERT_EQ(comm.World(), n_workers);
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
42
tests/cpp/plugin/federated/test_worker.h
Normal file
42
tests/cpp/plugin/federated/test_worker.h
Normal file
@@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for ms
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../../plugin/federated/federated_tracker.h"
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
template <typename WorkerFn>
|
||||
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
Json config{Object()};
|
||||
config["federated_secure"] = Boolean{false};
|
||||
config["n_workers"] = Integer{n_workers};
|
||||
FederatedTracker tracker{config};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
using namespace std::chrono_literals;
|
||||
while (tracker.Port() == 0) {
|
||||
std::this_thread::sleep_for(100ms);
|
||||
}
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] { fn(port, i); });
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
auto rc = tracker.Shutdown();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2022-2023 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@@ -26,7 +26,7 @@ class ServerForTest {
|
||||
explicit ServerForTest(std::size_t world_size) {
|
||||
server_thread_.reset(new std::thread([this, world_size] {
|
||||
grpc::ServerBuilder builder;
|
||||
xgboost::federated::FederatedService service{world_size};
|
||||
xgboost::federated::FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
int selected_port;
|
||||
builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port);
|
||||
builder.RegisterService(&service);
|
||||
|
||||
Reference in New Issue
Block a user