diff --git a/tests/cpp/objective_helpers.h b/tests/cpp/objective_helpers.h new file mode 100644 index 000000000..b26470746 --- /dev/null +++ b/tests/cpp/objective_helpers.h @@ -0,0 +1,32 @@ +/** + * Copyright (c) 2023, XGBoost contributors + */ +#include // for Registry +#include +#include // for ObjFunctionReg + +#include // for transform +#include // for back_insert_iterator, back_inserter +#include // for string +#include // for vector + +namespace xgboost { +inline auto MakeObjNamesForTest() { + auto list = ::dmlc::Registry<::xgboost::ObjFunctionReg>::List(); + std::vector names; + std::transform(list.cbegin(), list.cend(), std::back_inserter(names), + [](auto const* entry) { return entry->name; }); + return names; +} + +template +inline std::string ObjTestNameGenerator(const ::testing::TestParamInfo& info) { + auto name = std::string{info.param}; + // Name must be a valid c++ symbol + auto it = std::find(name.cbegin(), name.cend(), ':'); + if (it != name.cend()) { + name[std::distance(name.cbegin(), it)] = '_'; + } + return name; +}; +} // namespace xgboost diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h index 7edfc5efc..10ba68b49 100644 --- a/tests/cpp/plugin/helpers.h +++ b/tests/cpp/plugin/helpers.h @@ -8,6 +8,7 @@ #include #include +#include // for thread, sleep_for #include "../../../plugin/federated/federated_server.h" #include "../../../src/collective/communicator-inl.h" @@ -33,13 +34,17 @@ inline std::string GetServerAddress() { namespace xgboost { -class BaseFederatedTest : public ::testing::Test { - protected: - void SetUp() override { +class ServerForTest { + std::string server_address_; + std::unique_ptr server_thread_; + std::unique_ptr server_; + + public: + explicit ServerForTest(std::int32_t world_size) { server_address_ = GetServerAddress(); - server_thread_.reset(new std::thread([this] { + server_thread_.reset(new std::thread([this, world_size] { grpc::ServerBuilder builder; - xgboost::federated::FederatedService service{kWorldSize}; + xgboost::federated::FederatedService service{world_size}; builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials()); builder.RegisterService(&service); server_ = builder.BuildAndStart(); @@ -47,15 +52,21 @@ class BaseFederatedTest : public ::testing::Test { })); } - void TearDown() override { + ~ServerForTest() { server_->Shutdown(); server_thread_->join(); } + auto Address() const { return server_address_; } +}; + +class BaseFederatedTest : public ::testing::Test { + protected: + void SetUp() override { server_ = std::make_unique(kWorldSize); } + + void TearDown() override { server_.reset(nullptr); } static int const kWorldSize{3}; - std::string server_address_; - std::unique_ptr server_thread_; - std::unique_ptr server_; + std::unique_ptr server_; }; template diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu index c4816ff18..a5e901f26 100644 --- a/tests/cpp/plugin/test_federated_adapter.cu +++ b/tests/cpp/plugin/test_federated_adapter.cu @@ -29,7 +29,7 @@ TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) { TEST_F(FederatedAdapterTest, DeviceAllReduceSum) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back([rank, server_address = server_address_] { + threads.emplace_back([rank, server_address = server_->Address()] { FederatedCommunicator comm{kWorldSize, rank, server_address}; // Assign device 0 to all workers, since we run gtest in a single-GPU machine DeviceCommunicatorAdapter adapter{0, &comm}; @@ -52,7 +52,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) { TEST_F(FederatedAdapterTest, DeviceAllGatherV) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back([rank, server_address = server_address_] { + threads.emplace_back([rank, server_address = server_->Address()] { FederatedCommunicator comm{kWorldSize, rank, server_address}; // Assign device 0 to all workers, since we run gtest in a single-GPU machine DeviceCommunicatorAdapter adapter{0, &comm}; diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 5177187c5..340849606 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -92,7 +92,7 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { config["federated_server_address"] = server_address; config["federated_world_size"] = std::string("1"); config["federated_rank"] = Integer(0); - auto *comm = FederatedCommunicator::Create(config); + FederatedCommunicator::Create(config); }; EXPECT_THROW(construct(), dmlc::Error); } @@ -104,7 +104,7 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) { config["federated_server_address"] = server_address; config["federated_world_size"] = 1; config["federated_rank"] = std::string("0"); - auto *comm = FederatedCommunicator::Create(config); + FederatedCommunicator::Create(config); }; EXPECT_THROW(construct(), dmlc::Error); } @@ -125,7 +125,7 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) { TEST_F(FederatedCommunicatorTest, Allgather) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_); + threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_->Address()); } for (auto &thread : threads) { thread.join(); @@ -135,7 +135,7 @@ TEST_F(FederatedCommunicatorTest, Allgather) { TEST_F(FederatedCommunicatorTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_); + threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_->Address()); } for (auto &thread : threads) { thread.join(); @@ -145,7 +145,7 @@ TEST_F(FederatedCommunicatorTest, Allreduce) { TEST_F(FederatedCommunicatorTest, Broadcast) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_); + threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_->Address()); } for (auto &thread : threads) { thread.join(); diff --git a/tests/cpp/plugin/test_federated_data.cc b/tests/cpp/plugin/test_federated_data.cc index ed877131e..c6efb84d5 100644 --- a/tests/cpp/plugin/test_federated_data.cc +++ b/tests/cpp/plugin/test_federated_data.cc @@ -38,8 +38,8 @@ void VerifyLoadUri() { auto index = 0; int offsets[] = {0, 8, 17}; int offset = offsets[rank]; - for (auto row = 0; row < kRows; row++) { - for (auto col = 0; col < kCols; col++) { + for (std::size_t row = 0; row < kRows; row++) { + for (std::size_t col = 0; col < kCols; col++) { EXPECT_EQ(entries[index].index, col + offset); index++; } @@ -48,6 +48,6 @@ void VerifyLoadUri() { } TEST_F(FederatedDataTest, LoadUri) { - RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyLoadUri); + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLoadUri); } } // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_learner.cc b/tests/cpp/plugin/test_federated_learner.cc index fe7fe6854..85d0a2b7d 100644 --- a/tests/cpp/plugin/test_federated_learner.cc +++ b/tests/cpp/plugin/test_federated_learner.cc @@ -8,13 +8,34 @@ #include "../../../plugin/federated/federated_server.h" #include "../../../src/collective/communicator-inl.h" +#include "../../../src/common/linalg_op.h" #include "../helpers.h" +#include "../objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator #include "helpers.h" namespace xgboost { +namespace { +auto MakeModel(std::string objective, std::shared_ptr dmat) { + std::unique_ptr learner{Learner::Create({dmat})}; + learner->SetParam("tree_method", "approx"); + learner->SetParam("objective", objective); + if (objective.find("quantile") != std::string::npos) { + learner->SetParam("quantile_alpha", "0.5"); + } + if (objective.find("multi") != std::string::npos) { + learner->SetParam("num_class", "3"); + } + learner->UpdateOneIter(0, dmat); + Json config{Object{}}; + learner->SaveConfig(&config); -void VerifyObjectives(size_t rows, size_t cols, std::vector const &expected_base_scores, - std::vector const &expected_models) { + Json model{Object{}}; + learner->SaveModel(&model); + return model; +} + +void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model, + std::string objective) { auto const world_size = collective::GetWorldSize(); auto const rank = collective::GetRank(); std::shared_ptr dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)}; @@ -28,76 +49,72 @@ void VerifyObjectives(size_t rows, size_t cols, std::vector const &expect h_lower[i] = 1; h_upper[i] = 10; } + + if (objective.find("rank:") != std::string::npos) { + auto h_label = dmat->Info().labels.HostView(); + std::size_t k = 0; + for (auto &v : h_label) { + v = k % 2 == 0; + ++k; + } + } } std::shared_ptr sliced{dmat->SliceCol(world_size, rank)}; - auto i = 0; - for (auto const *entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { - std::unique_ptr learner{Learner::Create({sliced})}; - learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", entry->name); - if (entry->name.find("quantile") != std::string::npos) { - learner->SetParam("quantile_alpha", "0.5"); - } - if (entry->name.find("multi") != std::string::npos) { - learner->SetParam("num_class", "3"); - } - learner->UpdateOneIter(0, sliced); - - Json config{Object{}}; - learner->SaveConfig(&config); - auto base_score = GetBaseScore(config); - ASSERT_EQ(base_score, expected_base_scores[i]); - - Json model{Object{}}; - learner->SaveModel(&model); - ASSERT_EQ(model, expected_models[i]); - - i++; - } + auto model = MakeModel(objective, sliced); + auto base_score = GetBaseScore(model); + ASSERT_EQ(base_score, expected_base_score); + ASSERT_EQ(model, expected_model); } +} // namespace + +class FederatedLearnerTest : public ::testing::TestWithParam { + std::unique_ptr server_; + static int const kWorldSize{3}; -class FederatedLearnerTest : public BaseFederatedTest { protected: - static auto constexpr kRows{16}; - static auto constexpr kCols{16}; + void SetUp() override { server_ = std::make_unique(kWorldSize); } + void TearDown() override { server_.reset(nullptr); } + + void Run(std::string objective) { + static auto constexpr kRows{16}; + static auto constexpr kCols{16}; + + std::shared_ptr dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; + + auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); + auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); + h_lower.resize(kRows); + h_upper.resize(kRows); + for (size_t i = 0; i < kRows; ++i) { + h_lower[i] = 1; + h_upper[i] = 10; + } + if (objective.find("rank:") != std::string::npos) { + auto h_label = dmat->Info().labels.HostView(); + std::size_t k = 0; + for (auto &v : h_label) { + v = k % 2 == 0; + ++k; + } + } + + auto model = MakeModel(objective, dmat); + auto score = GetBaseScore(model); + + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols, + score, model, objective); + } }; -TEST_F(FederatedLearnerTest, Objectives) { - std::shared_ptr dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; - - auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); - auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); - h_lower.resize(kRows); - h_upper.resize(kRows); - for (size_t i = 0; i < kRows; ++i) { - h_lower[i] = 1; - h_upper[i] = 10; - } - - std::vector base_scores; - std::vector models; - for (auto const *entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { - std::unique_ptr learner{Learner::Create({dmat})}; - learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", entry->name); - if (entry->name.find("quantile") != std::string::npos) { - learner->SetParam("quantile_alpha", "0.5"); - } - if (entry->name.find("multi") != std::string::npos) { - learner->SetParam("num_class", "3"); - } - learner->UpdateOneIter(0, dmat); - Json config{Object{}}; - learner->SaveConfig(&config); - base_scores.emplace_back(GetBaseScore(config)); - - Json model{Object{}}; - learner->SaveModel(&model); - models.emplace_back(model); - } - - RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyObjectives, kRows, kCols, - base_scores, models); +TEST_P(FederatedLearnerTest, Objective) { + std::string objective = GetParam(); + this->Run(objective); } + +INSTANTIATE_TEST_SUITE_P(FederatedLearnerObjective, FederatedLearnerTest, + ::testing::ValuesIn(MakeObjNamesForTest()), + [](const ::testing::TestParamInfo &info) { + return ObjTestNameGenerator(info); + }); } // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc index 79e06bf5f..4dd2f3c40 100644 --- a/tests/cpp/plugin/test_federated_server.cc +++ b/tests/cpp/plugin/test_federated_server.cc @@ -73,7 +73,7 @@ class FederatedServerTest : public BaseFederatedTest { TEST_F(FederatedServerTest, Allgather) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_address_); + threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_->Address()); } for (auto& thread : threads) { thread.join(); @@ -83,7 +83,7 @@ TEST_F(FederatedServerTest, Allgather) { TEST_F(FederatedServerTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_address_); + threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_->Address()); } for (auto& thread : threads) { thread.join(); @@ -93,7 +93,7 @@ TEST_F(FederatedServerTest, Allreduce) { TEST_F(FederatedServerTest, Broadcast) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_address_); + threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_->Address()); } for (auto& thread : threads) { thread.join(); @@ -103,7 +103,7 @@ TEST_F(FederatedServerTest, Broadcast) { TEST_F(FederatedServerTest, Mixture) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_address_); + threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_->Address()); } for (auto& thread : threads) { thread.join(); diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index b43a0ecc1..91e8070c2 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -33,6 +33,7 @@ #include "dmlc/registry.h" // for Registry #include "filesystem.h" // for TemporaryDirectory #include "helpers.h" // for GetBaseScore, RandomDataGenerator +#include "objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator #include "xgboost/base.h" // for bst_float, Args, bst_feature_t, bst_int #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for DMatrix, MetaInfo, DataType @@ -715,22 +716,9 @@ TEST_P(TestColumnSplit, Objective) { this->Run(objective); } -auto MakeValues() { - auto list = ::dmlc::Registry<::xgboost::ObjFunctionReg>::List(); - std::vector names; - std::transform(list.cbegin(), list.cend(), std::back_inserter(names), - [](auto const* entry) { return entry->name; }); - return names; -} - -INSTANTIATE_TEST_SUITE_P(ColumnSplitObjective, TestColumnSplit, ::testing::ValuesIn(MakeValues()), +INSTANTIATE_TEST_SUITE_P(ColumnSplitObjective, TestColumnSplit, + ::testing::ValuesIn(MakeObjNamesForTest()), [](const ::testing::TestParamInfo& info) { - auto name = std::string{info.param}; - // Name must be a valid c++ symbol - auto it = std::find(name.cbegin(), name.cend(), ':'); - if (it != name.cend()) { - name[std::distance(name.cbegin(), it)] = '_'; - } - return name; + return ObjTestNameGenerator(info); }); } // namespace xgboost