Convert federated learner test into test suite. (#9018)
* Convert federated learner test into test suite. - Add specialization to learning to rank.
This commit is contained in:
parent
2c8d735cb3
commit
fe9dff339c
32
tests/cpp/objective_helpers.h
Normal file
32
tests/cpp/objective_helpers.h
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2023, XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <dmlc/registry.h> // for Registry
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/objective.h> // for ObjFunctionReg
|
||||||
|
|
||||||
|
#include <algorithm> // for transform
|
||||||
|
#include <iterator> // for back_insert_iterator, back_inserter
|
||||||
|
#include <string> // for string
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
inline auto MakeObjNamesForTest() {
|
||||||
|
auto list = ::dmlc::Registry<::xgboost::ObjFunctionReg>::List();
|
||||||
|
std::vector<std::string> names;
|
||||||
|
std::transform(list.cbegin(), list.cend(), std::back_inserter(names),
|
||||||
|
[](auto const* entry) { return entry->name; });
|
||||||
|
return names;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ParamType>
|
||||||
|
inline std::string ObjTestNameGenerator(const ::testing::TestParamInfo<ParamType>& info) {
|
||||||
|
auto name = std::string{info.param};
|
||||||
|
// Name must be a valid c++ symbol
|
||||||
|
auto it = std::find(name.cbegin(), name.cend(), ':');
|
||||||
|
if (it != name.cend()) {
|
||||||
|
name[std::distance(name.cbegin(), it)] = '_';
|
||||||
|
}
|
||||||
|
return name;
|
||||||
|
};
|
||||||
|
} // namespace xgboost
|
||||||
@ -8,6 +8,7 @@
|
|||||||
#include <xgboost/json.h>
|
#include <xgboost/json.h>
|
||||||
|
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <thread> // for thread, sleep_for
|
||||||
|
|
||||||
#include "../../../plugin/federated/federated_server.h"
|
#include "../../../plugin/federated/federated_server.h"
|
||||||
#include "../../../src/collective/communicator-inl.h"
|
#include "../../../src/collective/communicator-inl.h"
|
||||||
@ -33,13 +34,17 @@ inline std::string GetServerAddress() {
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
class BaseFederatedTest : public ::testing::Test {
|
class ServerForTest {
|
||||||
protected:
|
std::string server_address_;
|
||||||
void SetUp() override {
|
std::unique_ptr<std::thread> server_thread_;
|
||||||
|
std::unique_ptr<grpc::Server> server_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit ServerForTest(std::int32_t world_size) {
|
||||||
server_address_ = GetServerAddress();
|
server_address_ = GetServerAddress();
|
||||||
server_thread_.reset(new std::thread([this] {
|
server_thread_.reset(new std::thread([this, world_size] {
|
||||||
grpc::ServerBuilder builder;
|
grpc::ServerBuilder builder;
|
||||||
xgboost::federated::FederatedService service{kWorldSize};
|
xgboost::federated::FederatedService service{world_size};
|
||||||
builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
|
builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
|
||||||
builder.RegisterService(&service);
|
builder.RegisterService(&service);
|
||||||
server_ = builder.BuildAndStart();
|
server_ = builder.BuildAndStart();
|
||||||
@ -47,15 +52,21 @@ class BaseFederatedTest : public ::testing::Test {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
void TearDown() override {
|
~ServerForTest() {
|
||||||
server_->Shutdown();
|
server_->Shutdown();
|
||||||
server_thread_->join();
|
server_thread_->join();
|
||||||
}
|
}
|
||||||
|
auto Address() const { return server_address_; }
|
||||||
|
};
|
||||||
|
|
||||||
|
class BaseFederatedTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
|
||||||
|
|
||||||
|
void TearDown() override { server_.reset(nullptr); }
|
||||||
|
|
||||||
static int const kWorldSize{3};
|
static int const kWorldSize{3};
|
||||||
std::string server_address_;
|
std::unique_ptr<ServerForTest> server_;
|
||||||
std::unique_ptr<std::thread> server_thread_;
|
|
||||||
std::unique_ptr<grpc::Server> server_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Function, typename... Args>
|
template <typename Function, typename... Args>
|
||||||
|
|||||||
@ -29,7 +29,7 @@ TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) {
|
|||||||
TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
|
TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||||
threads.emplace_back([rank, server_address = server_address_] {
|
threads.emplace_back([rank, server_address = server_->Address()] {
|
||||||
FederatedCommunicator comm{kWorldSize, rank, server_address};
|
FederatedCommunicator comm{kWorldSize, rank, server_address};
|
||||||
// Assign device 0 to all workers, since we run gtest in a single-GPU machine
|
// Assign device 0 to all workers, since we run gtest in a single-GPU machine
|
||||||
DeviceCommunicatorAdapter adapter{0, &comm};
|
DeviceCommunicatorAdapter adapter{0, &comm};
|
||||||
@ -52,7 +52,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
|
|||||||
TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
|
TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||||
threads.emplace_back([rank, server_address = server_address_] {
|
threads.emplace_back([rank, server_address = server_->Address()] {
|
||||||
FederatedCommunicator comm{kWorldSize, rank, server_address};
|
FederatedCommunicator comm{kWorldSize, rank, server_address};
|
||||||
// Assign device 0 to all workers, since we run gtest in a single-GPU machine
|
// Assign device 0 to all workers, since we run gtest in a single-GPU machine
|
||||||
DeviceCommunicatorAdapter adapter{0, &comm};
|
DeviceCommunicatorAdapter adapter{0, &comm};
|
||||||
|
|||||||
@ -92,7 +92,7 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
|
|||||||
config["federated_server_address"] = server_address;
|
config["federated_server_address"] = server_address;
|
||||||
config["federated_world_size"] = std::string("1");
|
config["federated_world_size"] = std::string("1");
|
||||||
config["federated_rank"] = Integer(0);
|
config["federated_rank"] = Integer(0);
|
||||||
auto *comm = FederatedCommunicator::Create(config);
|
FederatedCommunicator::Create(config);
|
||||||
};
|
};
|
||||||
EXPECT_THROW(construct(), dmlc::Error);
|
EXPECT_THROW(construct(), dmlc::Error);
|
||||||
}
|
}
|
||||||
@ -104,7 +104,7 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) {
|
|||||||
config["federated_server_address"] = server_address;
|
config["federated_server_address"] = server_address;
|
||||||
config["federated_world_size"] = 1;
|
config["federated_world_size"] = 1;
|
||||||
config["federated_rank"] = std::string("0");
|
config["federated_rank"] = std::string("0");
|
||||||
auto *comm = FederatedCommunicator::Create(config);
|
FederatedCommunicator::Create(config);
|
||||||
};
|
};
|
||||||
EXPECT_THROW(construct(), dmlc::Error);
|
EXPECT_THROW(construct(), dmlc::Error);
|
||||||
}
|
}
|
||||||
@ -125,7 +125,7 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
|
|||||||
TEST_F(FederatedCommunicatorTest, Allgather) {
|
TEST_F(FederatedCommunicatorTest, Allgather) {
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_);
|
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_->Address());
|
||||||
}
|
}
|
||||||
for (auto &thread : threads) {
|
for (auto &thread : threads) {
|
||||||
thread.join();
|
thread.join();
|
||||||
@ -135,7 +135,7 @@ TEST_F(FederatedCommunicatorTest, Allgather) {
|
|||||||
TEST_F(FederatedCommunicatorTest, Allreduce) {
|
TEST_F(FederatedCommunicatorTest, Allreduce) {
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_);
|
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_->Address());
|
||||||
}
|
}
|
||||||
for (auto &thread : threads) {
|
for (auto &thread : threads) {
|
||||||
thread.join();
|
thread.join();
|
||||||
@ -145,7 +145,7 @@ TEST_F(FederatedCommunicatorTest, Allreduce) {
|
|||||||
TEST_F(FederatedCommunicatorTest, Broadcast) {
|
TEST_F(FederatedCommunicatorTest, Broadcast) {
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_);
|
threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_->Address());
|
||||||
}
|
}
|
||||||
for (auto &thread : threads) {
|
for (auto &thread : threads) {
|
||||||
thread.join();
|
thread.join();
|
||||||
|
|||||||
@ -38,8 +38,8 @@ void VerifyLoadUri() {
|
|||||||
auto index = 0;
|
auto index = 0;
|
||||||
int offsets[] = {0, 8, 17};
|
int offsets[] = {0, 8, 17};
|
||||||
int offset = offsets[rank];
|
int offset = offsets[rank];
|
||||||
for (auto row = 0; row < kRows; row++) {
|
for (std::size_t row = 0; row < kRows; row++) {
|
||||||
for (auto col = 0; col < kCols; col++) {
|
for (std::size_t col = 0; col < kCols; col++) {
|
||||||
EXPECT_EQ(entries[index].index, col + offset);
|
EXPECT_EQ(entries[index].index, col + offset);
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
@ -48,6 +48,6 @@ void VerifyLoadUri() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(FederatedDataTest, LoadUri) {
|
TEST_F(FederatedDataTest, LoadUri) {
|
||||||
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyLoadUri);
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLoadUri);
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -8,13 +8,34 @@
|
|||||||
|
|
||||||
#include "../../../plugin/federated/federated_server.h"
|
#include "../../../plugin/federated/federated_server.h"
|
||||||
#include "../../../src/collective/communicator-inl.h"
|
#include "../../../src/collective/communicator-inl.h"
|
||||||
|
#include "../../../src/common/linalg_op.h"
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
#include "../objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator
|
||||||
#include "helpers.h"
|
#include "helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
namespace {
|
||||||
|
auto MakeModel(std::string objective, std::shared_ptr<DMatrix> dmat) {
|
||||||
|
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
|
||||||
|
learner->SetParam("tree_method", "approx");
|
||||||
|
learner->SetParam("objective", objective);
|
||||||
|
if (objective.find("quantile") != std::string::npos) {
|
||||||
|
learner->SetParam("quantile_alpha", "0.5");
|
||||||
|
}
|
||||||
|
if (objective.find("multi") != std::string::npos) {
|
||||||
|
learner->SetParam("num_class", "3");
|
||||||
|
}
|
||||||
|
learner->UpdateOneIter(0, dmat);
|
||||||
|
Json config{Object{}};
|
||||||
|
learner->SaveConfig(&config);
|
||||||
|
|
||||||
void VerifyObjectives(size_t rows, size_t cols, std::vector<float> const &expected_base_scores,
|
Json model{Object{}};
|
||||||
std::vector<Json> const &expected_models) {
|
learner->SaveModel(&model);
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model,
|
||||||
|
std::string objective) {
|
||||||
auto const world_size = collective::GetWorldSize();
|
auto const world_size = collective::GetWorldSize();
|
||||||
auto const rank = collective::GetRank();
|
auto const rank = collective::GetRank();
|
||||||
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
|
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
|
||||||
@ -28,76 +49,72 @@ void VerifyObjectives(size_t rows, size_t cols, std::vector<float> const &expect
|
|||||||
h_lower[i] = 1;
|
h_lower[i] = 1;
|
||||||
h_upper[i] = 10;
|
h_upper[i] = 10;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (objective.find("rank:") != std::string::npos) {
|
||||||
|
auto h_label = dmat->Info().labels.HostView();
|
||||||
|
std::size_t k = 0;
|
||||||
|
for (auto &v : h_label) {
|
||||||
|
v = k % 2 == 0;
|
||||||
|
++k;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};
|
std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};
|
||||||
|
|
||||||
auto i = 0;
|
auto model = MakeModel(objective, sliced);
|
||||||
for (auto const *entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
|
auto base_score = GetBaseScore(model);
|
||||||
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
|
ASSERT_EQ(base_score, expected_base_score);
|
||||||
learner->SetParam("tree_method", "approx");
|
ASSERT_EQ(model, expected_model);
|
||||||
learner->SetParam("objective", entry->name);
|
|
||||||
if (entry->name.find("quantile") != std::string::npos) {
|
|
||||||
learner->SetParam("quantile_alpha", "0.5");
|
|
||||||
}
|
|
||||||
if (entry->name.find("multi") != std::string::npos) {
|
|
||||||
learner->SetParam("num_class", "3");
|
|
||||||
}
|
|
||||||
learner->UpdateOneIter(0, sliced);
|
|
||||||
|
|
||||||
Json config{Object{}};
|
|
||||||
learner->SaveConfig(&config);
|
|
||||||
auto base_score = GetBaseScore(config);
|
|
||||||
ASSERT_EQ(base_score, expected_base_scores[i]);
|
|
||||||
|
|
||||||
Json model{Object{}};
|
|
||||||
learner->SaveModel(&model);
|
|
||||||
ASSERT_EQ(model, expected_models[i]);
|
|
||||||
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
|
||||||
|
std::unique_ptr<ServerForTest> server_;
|
||||||
|
static int const kWorldSize{3};
|
||||||
|
|
||||||
class FederatedLearnerTest : public BaseFederatedTest {
|
|
||||||
protected:
|
protected:
|
||||||
static auto constexpr kRows{16};
|
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
|
||||||
static auto constexpr kCols{16};
|
void TearDown() override { server_.reset(nullptr); }
|
||||||
|
|
||||||
|
void Run(std::string objective) {
|
||||||
|
static auto constexpr kRows{16};
|
||||||
|
static auto constexpr kCols{16};
|
||||||
|
|
||||||
|
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
|
||||||
|
|
||||||
|
auto &h_upper = dmat->Info().labels_upper_bound_.HostVector();
|
||||||
|
auto &h_lower = dmat->Info().labels_lower_bound_.HostVector();
|
||||||
|
h_lower.resize(kRows);
|
||||||
|
h_upper.resize(kRows);
|
||||||
|
for (size_t i = 0; i < kRows; ++i) {
|
||||||
|
h_lower[i] = 1;
|
||||||
|
h_upper[i] = 10;
|
||||||
|
}
|
||||||
|
if (objective.find("rank:") != std::string::npos) {
|
||||||
|
auto h_label = dmat->Info().labels.HostView();
|
||||||
|
std::size_t k = 0;
|
||||||
|
for (auto &v : h_label) {
|
||||||
|
v = k % 2 == 0;
|
||||||
|
++k;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto model = MakeModel(objective, dmat);
|
||||||
|
auto score = GetBaseScore(model);
|
||||||
|
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols,
|
||||||
|
score, model, objective);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(FederatedLearnerTest, Objectives) {
|
TEST_P(FederatedLearnerTest, Objective) {
|
||||||
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
|
std::string objective = GetParam();
|
||||||
|
this->Run(objective);
|
||||||
auto &h_upper = dmat->Info().labels_upper_bound_.HostVector();
|
|
||||||
auto &h_lower = dmat->Info().labels_lower_bound_.HostVector();
|
|
||||||
h_lower.resize(kRows);
|
|
||||||
h_upper.resize(kRows);
|
|
||||||
for (size_t i = 0; i < kRows; ++i) {
|
|
||||||
h_lower[i] = 1;
|
|
||||||
h_upper[i] = 10;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<float> base_scores;
|
|
||||||
std::vector<Json> models;
|
|
||||||
for (auto const *entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
|
|
||||||
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
|
|
||||||
learner->SetParam("tree_method", "approx");
|
|
||||||
learner->SetParam("objective", entry->name);
|
|
||||||
if (entry->name.find("quantile") != std::string::npos) {
|
|
||||||
learner->SetParam("quantile_alpha", "0.5");
|
|
||||||
}
|
|
||||||
if (entry->name.find("multi") != std::string::npos) {
|
|
||||||
learner->SetParam("num_class", "3");
|
|
||||||
}
|
|
||||||
learner->UpdateOneIter(0, dmat);
|
|
||||||
Json config{Object{}};
|
|
||||||
learner->SaveConfig(&config);
|
|
||||||
base_scores.emplace_back(GetBaseScore(config));
|
|
||||||
|
|
||||||
Json model{Object{}};
|
|
||||||
learner->SaveModel(&model);
|
|
||||||
models.emplace_back(model);
|
|
||||||
}
|
|
||||||
|
|
||||||
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyObjectives, kRows, kCols,
|
|
||||||
base_scores, models);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(FederatedLearnerObjective, FederatedLearnerTest,
|
||||||
|
::testing::ValuesIn(MakeObjNamesForTest()),
|
||||||
|
[](const ::testing::TestParamInfo<FederatedLearnerTest::ParamType> &info) {
|
||||||
|
return ObjTestNameGenerator(info);
|
||||||
|
});
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -73,7 +73,7 @@ class FederatedServerTest : public BaseFederatedTest {
|
|||||||
TEST_F(FederatedServerTest, Allgather) {
|
TEST_F(FederatedServerTest, Allgather) {
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||||
threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_address_);
|
threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_->Address());
|
||||||
}
|
}
|
||||||
for (auto& thread : threads) {
|
for (auto& thread : threads) {
|
||||||
thread.join();
|
thread.join();
|
||||||
@ -83,7 +83,7 @@ TEST_F(FederatedServerTest, Allgather) {
|
|||||||
TEST_F(FederatedServerTest, Allreduce) {
|
TEST_F(FederatedServerTest, Allreduce) {
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||||
threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_address_);
|
threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_->Address());
|
||||||
}
|
}
|
||||||
for (auto& thread : threads) {
|
for (auto& thread : threads) {
|
||||||
thread.join();
|
thread.join();
|
||||||
@ -93,7 +93,7 @@ TEST_F(FederatedServerTest, Allreduce) {
|
|||||||
TEST_F(FederatedServerTest, Broadcast) {
|
TEST_F(FederatedServerTest, Broadcast) {
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||||
threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_address_);
|
threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_->Address());
|
||||||
}
|
}
|
||||||
for (auto& thread : threads) {
|
for (auto& thread : threads) {
|
||||||
thread.join();
|
thread.join();
|
||||||
@ -103,7 +103,7 @@ TEST_F(FederatedServerTest, Broadcast) {
|
|||||||
TEST_F(FederatedServerTest, Mixture) {
|
TEST_F(FederatedServerTest, Mixture) {
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||||
threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_address_);
|
threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_->Address());
|
||||||
}
|
}
|
||||||
for (auto& thread : threads) {
|
for (auto& thread : threads) {
|
||||||
thread.join();
|
thread.join();
|
||||||
|
|||||||
@ -33,6 +33,7 @@
|
|||||||
#include "dmlc/registry.h" // for Registry
|
#include "dmlc/registry.h" // for Registry
|
||||||
#include "filesystem.h" // for TemporaryDirectory
|
#include "filesystem.h" // for TemporaryDirectory
|
||||||
#include "helpers.h" // for GetBaseScore, RandomDataGenerator
|
#include "helpers.h" // for GetBaseScore, RandomDataGenerator
|
||||||
|
#include "objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator
|
||||||
#include "xgboost/base.h" // for bst_float, Args, bst_feature_t, bst_int
|
#include "xgboost/base.h" // for bst_float, Args, bst_feature_t, bst_int
|
||||||
#include "xgboost/context.h" // for Context
|
#include "xgboost/context.h" // for Context
|
||||||
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType
|
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType
|
||||||
@ -715,22 +716,9 @@ TEST_P(TestColumnSplit, Objective) {
|
|||||||
this->Run(objective);
|
this->Run(objective);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto MakeValues() {
|
INSTANTIATE_TEST_SUITE_P(ColumnSplitObjective, TestColumnSplit,
|
||||||
auto list = ::dmlc::Registry<::xgboost::ObjFunctionReg>::List();
|
::testing::ValuesIn(MakeObjNamesForTest()),
|
||||||
std::vector<std::string> names;
|
|
||||||
std::transform(list.cbegin(), list.cend(), std::back_inserter(names),
|
|
||||||
[](auto const* entry) { return entry->name; });
|
|
||||||
return names;
|
|
||||||
}
|
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(ColumnSplitObjective, TestColumnSplit, ::testing::ValuesIn(MakeValues()),
|
|
||||||
[](const ::testing::TestParamInfo<TestColumnSplit::ParamType>& info) {
|
[](const ::testing::TestParamInfo<TestColumnSplit::ParamType>& info) {
|
||||||
auto name = std::string{info.param};
|
return ObjTestNameGenerator(info);
|
||||||
// Name must be a valid c++ symbol
|
|
||||||
auto it = std::find(name.cbegin(), name.cend(), ':');
|
|
||||||
if (it != name.cend()) {
|
|
||||||
name[std::distance(name.cbegin(), it)] = '_';
|
|
||||||
}
|
|
||||||
return name;
|
|
||||||
});
|
});
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user