/*! * Copyright 2023 XGBoost contributors */ #include #include #include #include #include "../../../plugin/federated/federated_server.h" #include "../../../src/collective/communicator-inl.h" #include "../../../src/common/linalg_op.h" #include "../helpers.h" #include "../objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator #include "helpers.h" namespace xgboost { namespace { auto MakeModel(std::string objective, std::shared_ptr dmat) { std::unique_ptr learner{Learner::Create({dmat})}; learner->SetParam("tree_method", "approx"); learner->SetParam("objective", objective); if (objective.find("quantile") != std::string::npos) { learner->SetParam("quantile_alpha", "0.5"); } if (objective.find("multi") != std::string::npos) { learner->SetParam("num_class", "3"); } learner->UpdateOneIter(0, dmat); Json config{Object{}}; learner->SaveConfig(&config); Json model{Object{}}; learner->SaveModel(&model); return model; } void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model, std::string objective) { auto const world_size = collective::GetWorldSize(); auto const rank = collective::GetRank(); std::shared_ptr dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)}; if (rank == 0) { auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); h_lower.resize(rows); h_upper.resize(rows); for (size_t i = 0; i < rows; ++i) { h_lower[i] = 1; h_upper[i] = 10; } if (objective.find("rank:") != std::string::npos) { auto h_label = dmat->Info().labels.HostView(); std::size_t k = 0; for (auto &v : h_label) { v = k % 2 == 0; ++k; } } } std::shared_ptr sliced{dmat->SliceCol(world_size, rank)}; auto model = MakeModel(objective, sliced); auto base_score = GetBaseScore(model); ASSERT_EQ(base_score, expected_base_score); ASSERT_EQ(model, expected_model); } } // namespace class FederatedLearnerTest : public ::testing::TestWithParam { std::unique_ptr server_; static int const kWorldSize{3}; protected: void SetUp() override { server_ = std::make_unique(kWorldSize); } void TearDown() override { server_.reset(nullptr); } void Run(std::string objective) { static auto constexpr kRows{16}; static auto constexpr kCols{16}; std::shared_ptr dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); h_lower.resize(kRows); h_upper.resize(kRows); for (size_t i = 0; i < kRows; ++i) { h_lower[i] = 1; h_upper[i] = 10; } if (objective.find("rank:") != std::string::npos) { auto h_label = dmat->Info().labels.HostView(); std::size_t k = 0; for (auto &v : h_label) { v = k % 2 == 0; ++k; } } auto model = MakeModel(objective, dmat); auto score = GetBaseScore(model); RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols, score, model, objective); } }; TEST_P(FederatedLearnerTest, Objective) { std::string objective = GetParam(); this->Run(objective); } INSTANTIATE_TEST_SUITE_P(FederatedLearnerObjective, FederatedLearnerTest, ::testing::ValuesIn(MakeObjNamesForTest()), [](const ::testing::TestParamInfo &info) { return ObjTestNameGenerator(info); }); } // namespace xgboost