From 42d100de188d6f1df29621f8f72c4e429bbfee22 Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Wed, 19 Apr 2023 00:39:11 -0700
Subject: [PATCH] Make sure metrics work with federated learning (#9037)

---
 src/collective/aggregator.h                |  62 ++++++
 src/learner.cc                             |  43 +---
 src/metric/auc.cc                          |   4 +-
 src/metric/metric_common.h                 |   9 +-
 src/metric/rank_metric.cc                  |  28 ++-
 tests/cpp/helpers.cc                       |   4 +-
 tests/cpp/metric/test_survival_metric.cu   |  99 +--------
 tests/cpp/metric/test_survival_metric.h    | 107 +++++++++
 tests/cpp/plugin/helpers.h                 |   2 +-
 tests/cpp/plugin/test_federated_learner.cc |   2 +-
 tests/cpp/plugin/test_federated_metrics.cc | 243 +++++++++++++++++++++
 11 files changed, 451 insertions(+), 152 deletions(-)
 create mode 100644 src/collective/aggregator.h
 create mode 100644 tests/cpp/metric/test_survival_metric.h
 create mode 100644 tests/cpp/plugin/test_federated_metrics.cc

diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h
new file mode 100644
index 000000000..ee499b4d1
--- /dev/null
+++ b/src/collective/aggregator.h
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2023 by XGBoost contributors
+ *
+ * Higher level functions built on top of the Communicator API, taking care of behavioral differences
+ * between row-split vs column-split distributed training, and horizontal vs vertical federated
+ * learning.
+ */
+#pragma once
+#include
+
+#include
+#include
+#include
+
+#include "communicator-inl.h"
+
+namespace xgboost {
+namespace collective {
+
+/**
+ * @brief Apply the given function where the labels are.
+ *
+ * Normally all the workers have access to the labels, so the function is just applied locally. In
+ * vertical federated learning, we assume labels are only available on worker 0, so the function is
+ * applied there, with the results broadcast to other workers.
+ *
+ * @tparam Function The function used to calculate the results.
+ * @tparam Args Arguments to the function.
+ * @param info MetaInfo about the DMatrix.
+ * @param buffer The buffer storing the results.
+ * @param size The size of the buffer.
+ * @param function The function used to calculate the results.
+ * @param args Arguments to the function.
+ */
+template <typename Function, typename... Args>
+void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function,
+                     Args&&... args) {
+  if (info.IsVerticalFederated()) {
+    // We assume labels are only available on worker 0, so the calculation is done there and result
+    // broadcast to other workers.
+    std::vector<char> message(1024);
+    if (collective::GetRank() == 0) {
+      try {
+        std::forward<Function>(function)(std::forward<Args>(args)...);
+      } catch (dmlc::Error& e) {
+        strncpy(&message[0], e.what(), message.size());
+        message.back() = '\0';
+      }
+    }
+    collective::Broadcast(&message[0], message.size(), 0);
+    if (strlen(&message[0]) == 0) {
+      collective::Broadcast(buffer, size, 0);
+    } else {
+      LOG(FATAL) << &message[0];
+    }
+  } else {
+    std::forward<Function>(function)(std::forward<Args>(args)...);
+  }
+}
+
+}  // namespace collective
+}  // namespace xgboost
diff --git a/src/learner.cc b/src/learner.cc
index 1150a2355..78297404b 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -34,6 +34,7 @@
 #include <utility>  // for pair, as_const, move, swap
 #include <vector>   // for vector
+#include "collective/aggregator.h"        // for ApplyWithLabels
 #include "collective/communicator-inl.h"  // for Allreduce, Broadcast, GetRank, IsDistributed
 #include "collective/communicator.h"      // for Operation
 #include "common/api_entry.h"             // for XGBAPIThreadLocalEntry
@@ -859,22 +860,10 @@ class LearnerConfiguration : public Learner {
   }

   void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
-    // Special handling for vertical federated learning.
-    if (info.IsVerticalFederated()) {
-      // We assume labels are only available on worker 0, so the estimation is calculated there
-      // and broadcast to other workers.
- if (collective::GetRank() == 0) { - UsePtr(obj_)->InitEstimation(info, base_score); - collective::Broadcast(base_score->Data()->HostPointer(), - sizeof(bst_float) * base_score->Size(), 0); - } else { - base_score->Reshape(1); - collective::Broadcast(base_score->Data()->HostPointer(), - sizeof(bst_float) * base_score->Size(), 0); - } - } else { - UsePtr(obj_)->InitEstimation(info, base_score); - } + base_score->Reshape(1); + collective::ApplyWithLabels(info, base_score->Data()->HostPointer(), + sizeof(bst_float) * base_score->Size(), + [&] { UsePtr(obj_)->InitEstimation(info, base_score); }); } }; @@ -1486,24 +1475,10 @@ class LearnerImpl : public LearnerIO { private: void GetGradient(HostDeviceVector const& preds, MetaInfo const& info, int iteration, HostDeviceVector* out_gpair) { - // Special handling for vertical federated learning. - if (info.IsVerticalFederated()) { - // We assume labels are only available on worker 0, so the gradients are calculated there - // and broadcast to other workers. - if (collective::GetRank() == 0) { - obj_->GetGradient(preds, info, iteration, out_gpair); - collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair), - 0); - } else { - CHECK_EQ(info.labels.Size(), 0) - << "In vertical federated learning, labels should only be on the first worker"; - out_gpair->Resize(preds.Size()); - collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair), - 0); - } - } else { - obj_->GetGradient(preds, info, iteration, out_gpair); - } + out_gpair->Resize(preds.Size()); + collective::ApplyWithLabels(info, out_gpair->HostPointer(), + out_gpair->Size() * sizeof(GradientPair), + [&] { obj_->GetGradient(preds, info, iteration, out_gpair); }); } /*! \brief random number transformation seed. 
*/ diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 2d4becfa8..bde3127ed 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -270,7 +270,9 @@ class EvalAUC : public MetricNoCache { } // We use the global size to handle empty dataset. std::array meta{info.labels.Size(), preds.Size()}; - collective::Allreduce(meta.data(), meta.size()); + if (!info.IsVerticalFederated()) { + collective::Allreduce(meta.data(), meta.size()); + } if (meta[0] == 0) { // Empty across all workers, which is not supported. auc = std::numeric_limits::quiet_NaN(); diff --git a/src/metric/metric_common.h b/src/metric/metric_common.h index 5fbd6f256..a6fad7158 100644 --- a/src/metric/metric_common.h +++ b/src/metric/metric_common.h @@ -9,6 +9,8 @@ #include // shared_ptr #include +#include "../collective/aggregator.h" +#include "../collective/communicator-inl.h" #include "../common/common.h" #include "xgboost/metric.h" @@ -20,7 +22,12 @@ class MetricNoCache : public Metric { virtual double Eval(HostDeviceVector const &predts, MetaInfo const &info) = 0; double Evaluate(HostDeviceVector const &predts, std::shared_ptr p_fmat) final { - return this->Eval(predts, p_fmat->Info()); + double result{0.0}; + auto const& info = p_fmat->Info(); + collective::ApplyWithLabels(info, &result, sizeof(double), [&] { + result = this->Eval(predts, info); + }); + return result; } }; diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 62efd0876..000b88e80 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -28,9 +28,8 @@ #include // for stable_sort, copy, fill_n, min, max #include // for array #include // for log, sqrt -#include // for size_t, std -#include // for uint32_t #include // for less, greater +#include // for numeric_limits #include // for operator!=, _Rb_tree_const_iterator #include // for allocator, unique_ptr, shared_ptr, __shared_... 
#include // for accumulate @@ -39,15 +38,11 @@ #include // for pair, make_pair #include // for vector -#include "../collective/communicator-inl.h" // for IsDistributed, Allreduce -#include "../collective/communicator.h" // for Operation +#include "../collective/aggregator.h" // for ApplyWithLabels #include "../common/algorithm.h" // for ArgSort, Sort #include "../common/linalg_op.h" // for cbegin, cend #include "../common/math.h" // for CmpFirst #include "../common/optional_weight.h" // for OptionalWeights, MakeOptionalWeights -#include "../common/ranking_utils.h" // for LambdaRankParam, NDCGCache, ParseMetricName -#include "../common/threading_utils.h" // for ParallelFor -#include "../common/transform_iterator.h" // for IndexTransformIter #include "dmlc/common.h" // for OMPException #include "metric_common.h" // for MetricNoCache, GPUMetric, PackedReduceResult #include "xgboost/base.h" // for bst_float, bst_omp_uint, bst_group_t, Args @@ -59,7 +54,6 @@ #include "xgboost/linalg.h" // for Tensor, TensorView, Range, VectorView, MakeT... 
#include "xgboost/logging.h" // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ #include "xgboost/metric.h" // for MetricReg, XGBOOST_REGISTER_METRIC, Metric -#include "xgboost/span.h" // for Span, operator!= #include "xgboost/string_view.h" // for StringView namespace { @@ -385,15 +379,19 @@ class EvalRankWithCache : public Metric { } double Evaluate(HostDeviceVector const& preds, std::shared_ptr p_fmat) override { + double result{0.0}; auto const& info = p_fmat->Info(); - auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_); - if (p_cache->Param() != param_) { - p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_); - } - CHECK(p_cache->Param() == param_); - CHECK_EQ(preds.Size(), info.labels.Size()); + collective::ApplyWithLabels(info, &result, sizeof(double), [&] { + auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_); + if (p_cache->Param() != param_) { + p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_); + } + CHECK(p_cache->Param() == param_); + CHECK_EQ(preds.Size(), info.labels.Size()); - return this->Eval(preds, info, p_cache); + result = this->Eval(preds, info, p_cache); + }); + return result; } virtual double Eval(HostDeviceVector const& preds, MetaInfo const& info, diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index a8b974f03..76fd2f967 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -189,7 +189,9 @@ double GetMultiMetricEval(xgboost::Metric* metric, info.weights_.HostVector() = weights; info.group_ptr_ = groups; info.data_split_mode = data_split_mode; - + if (info.IsVerticalFederated() && xgboost::collective::GetRank() != 0) { + info.labels.Reshape(0); + } return metric->Evaluate(preds, p_fmat); } diff --git a/tests/cpp/metric/test_survival_metric.cu b/tests/cpp/metric/test_survival_metric.cu index d7ac54860..723f306e4 100644 --- a/tests/cpp/metric/test_survival_metric.cu +++ b/tests/cpp/metric/test_survival_metric.cu @@ -2,109 +2,13 @@ * Copyright (c) by Contributors 2020 */ #include -#include +#include 
"test_survival_metric.h" #include "xgboost/metric.h" -#include "../helpers.h" -#include "../../../src/common/survival_util.h" /** Tests for Survival metrics that should run both on CPU and GPU **/ namespace xgboost { namespace common { -namespace { -inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) { - auto ctx = CreateEmptyGenericParam(device); - std::unique_ptr metric{Metric::Create(name.c_str(), &ctx)}; - metric->Configure(Args{}); - - HostDeviceVector predts; - auto p_fmat = EmptyDMatrix(); - MetaInfo& info = p_fmat->Info(); - auto &h_predts = predts.HostVector(); - - SimpleLCG lcg; - SimpleRealUniformDistribution dist{0.0f, 1.0f}; - - size_t n_samples = 2048; - h_predts.resize(n_samples); - - for (size_t i = 0; i < n_samples; ++i) { - h_predts[i] = dist(&lcg); - } - - auto &h_upper = info.labels_upper_bound_.HostVector(); - auto &h_lower = info.labels_lower_bound_.HostVector(); - h_lower.resize(n_samples); - h_upper.resize(n_samples); - for (size_t i = 0; i < n_samples; ++i) { - h_lower[i] = 1; - h_upper[i] = 10; - } - - auto result = metric->Evaluate(predts, p_fmat); - for (size_t i = 0; i < 8; ++i) { - ASSERT_EQ(metric->Evaluate(predts, p_fmat), result); - } -} - -void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - - /** - * Test aggregate output from the AFT metric over a small test data set. - * This is unlike AFTLoss.* tests, which verify metric values over individual data points. 
- **/ - auto p_fmat = EmptyDMatrix(); - MetaInfo& info = p_fmat->Info(); - info.num_row_ = 4; - info.labels_lower_bound_.HostVector() - = { 100.0f, 0.0f, 60.0f, 16.0f }; - info.labels_upper_bound_.HostVector() - = { 100.0f, 20.0f, std::numeric_limits::infinity(), 200.0f }; - info.weights_.HostVector() = std::vector(); - info.data_split_mode = data_split_mode; - HostDeviceVector preds(4, std::log(64)); - - struct TestCase { - std::string dist_type; - bst_float reference_value; - }; - for (const auto& test_case : std::vector{ {"normal", 2.1508f}, {"logistic", 2.1804f}, - {"extreme", 2.0706f} }) { - std::unique_ptr metric(Metric::Create("aft-nloglik", &ctx)); - metric->Configure({ {"aft_loss_distribution", test_case.dist_type}, - {"aft_loss_distribution_scale", "1.0"} }); - EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4); - } -} - -void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - - auto p_fmat = EmptyDMatrix(); - MetaInfo& info = p_fmat->Info(); - info.num_row_ = 4; - info.labels_lower_bound_.HostVector() = { 20.0f, 0.0f, 60.0f, 16.0f }; - info.labels_upper_bound_.HostVector() = { 80.0f, 20.0f, 80.0f, 200.0f }; - info.weights_.HostVector() = std::vector(); - info.data_split_mode = data_split_mode; - HostDeviceVector preds(4, std::log(60.0f)); - - std::unique_ptr metric(Metric::Create("interval-regression-accuracy", &ctx)); - EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f); - info.labels_lower_bound_.HostVector()[2] = 70.0f; - EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f); - info.labels_upper_bound_.HostVector()[2] = std::numeric_limits::infinity(); - EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f); - info.labels_upper_bound_.HostVector()[3] = std::numeric_limits::infinity(); - EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f); - info.labels_lower_bound_.HostVector()[0] = 70.0f; - 
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f); - - CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX); -} -} // anonymous namespace - TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) { VerifyAFTNegLogLik(); } TEST_F(DeclareUnifiedDistributedTest(MetricTest), AFTNegLogLikRowSplit) { @@ -140,6 +44,5 @@ TEST(AFTNegLogLikMetric, DeclareUnifiedTest(Configuration)) { CheckDeterministicMetricElementWise(StringView{"aft-nloglik"}, GPUIDX); } - } // namespace common } // namespace xgboost diff --git a/tests/cpp/metric/test_survival_metric.h b/tests/cpp/metric/test_survival_metric.h new file mode 100644 index 000000000..75414733d --- /dev/null +++ b/tests/cpp/metric/test_survival_metric.h @@ -0,0 +1,107 @@ +/** + * Copyright 2020-2023 by XGBoost Contributors + */ +#pragma once +#include + +#include + +#include "../../../src/common/survival_util.h" +#include "../helpers.h" +#include "xgboost/metric.h" + +namespace xgboost { +namespace common { +inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) { + auto ctx = CreateEmptyGenericParam(device); + std::unique_ptr metric{Metric::Create(name.c_str(), &ctx)}; + metric->Configure(Args{}); + + HostDeviceVector predts; + auto p_fmat = EmptyDMatrix(); + MetaInfo& info = p_fmat->Info(); + auto &h_predts = predts.HostVector(); + + SimpleLCG lcg; + SimpleRealUniformDistribution dist{0.0f, 1.0f}; + + size_t n_samples = 2048; + h_predts.resize(n_samples); + + for (size_t i = 0; i < n_samples; ++i) { + h_predts[i] = dist(&lcg); + } + + auto &h_upper = info.labels_upper_bound_.HostVector(); + auto &h_lower = info.labels_lower_bound_.HostVector(); + h_lower.resize(n_samples); + h_upper.resize(n_samples); + for (size_t i = 0; i < n_samples; ++i) { + h_lower[i] = 1; + h_upper[i] = 10; + } + + auto result = metric->Evaluate(predts, p_fmat); + for (size_t i = 0; i < 8; ++i) { + ASSERT_EQ(metric->Evaluate(predts, p_fmat), result); + } +} + +inline void 
VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) { + auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); + + /** + * Test aggregate output from the AFT metric over a small test data set. + * This is unlike AFTLoss.* tests, which verify metric values over individual data points. + **/ + auto p_fmat = EmptyDMatrix(); + MetaInfo& info = p_fmat->Info(); + info.num_row_ = 4; + info.labels_lower_bound_.HostVector() + = { 100.0f, 0.0f, 60.0f, 16.0f }; + info.labels_upper_bound_.HostVector() + = { 100.0f, 20.0f, std::numeric_limits::infinity(), 200.0f }; + info.weights_.HostVector() = std::vector(); + info.data_split_mode = data_split_mode; + HostDeviceVector preds(4, std::log(64)); + + struct TestCase { + std::string dist_type; + bst_float reference_value; + }; + for (const auto& test_case : std::vector{ {"normal", 2.1508f}, {"logistic", 2.1804f}, + {"extreme", 2.0706f} }) { + std::unique_ptr metric(Metric::Create("aft-nloglik", &ctx)); + metric->Configure({ {"aft_loss_distribution", test_case.dist_type}, + {"aft_loss_distribution_scale", "1.0"} }); + EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4); + } +} + +inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) { + auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); + + auto p_fmat = EmptyDMatrix(); + MetaInfo& info = p_fmat->Info(); + info.num_row_ = 4; + info.labels_lower_bound_.HostVector() = { 20.0f, 0.0f, 60.0f, 16.0f }; + info.labels_upper_bound_.HostVector() = { 80.0f, 20.0f, 80.0f, 200.0f }; + info.weights_.HostVector() = std::vector(); + info.data_split_mode = data_split_mode; + HostDeviceVector preds(4, std::log(60.0f)); + + std::unique_ptr metric(Metric::Create("interval-regression-accuracy", &ctx)); + EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f); + info.labels_lower_bound_.HostVector()[2] = 70.0f; + EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f); + info.labels_upper_bound_.HostVector()[2] = 
std::numeric_limits::infinity(); + EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f); + info.labels_upper_bound_.HostVector()[3] = std::numeric_limits::infinity(); + EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f); + info.labels_lower_bound_.HostVector()[0] = 70.0f; + EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f); + + CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX); +} +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h index 10ba68b49..41e5a63e5 100644 --- a/tests/cpp/plugin/helpers.h +++ b/tests/cpp/plugin/helpers.h @@ -65,7 +65,7 @@ class BaseFederatedTest : public ::testing::Test { void TearDown() override { server_.reset(nullptr); } - static int const kWorldSize{3}; + static int constexpr kWorldSize{3}; std::unique_ptr server_; }; diff --git a/tests/cpp/plugin/test_federated_learner.cc b/tests/cpp/plugin/test_federated_learner.cc index 85d0a2b7d..b7066b6a0 100644 --- a/tests/cpp/plugin/test_federated_learner.cc +++ b/tests/cpp/plugin/test_federated_learner.cc @@ -70,7 +70,7 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e class FederatedLearnerTest : public ::testing::TestWithParam { std::unique_ptr server_; - static int const kWorldSize{3}; + static int constexpr kWorldSize{3}; protected: void SetUp() override { server_ = std::make_unique(kWorldSize); } diff --git a/tests/cpp/plugin/test_federated_metrics.cc b/tests/cpp/plugin/test_federated_metrics.cc new file mode 100644 index 000000000..1bdda567f --- /dev/null +++ b/tests/cpp/plugin/test_federated_metrics.cc @@ -0,0 +1,243 @@ +/*! 
+ * Copyright 2023 XGBoost contributors + */ +#include + +#include "../metric/test_auc.h" +#include "../metric/test_elementwise_metric.h" +#include "../metric/test_multiclass_metric.h" +#include "../metric/test_rank_metric.h" +#include "../metric/test_survival_metric.h" +#include "helpers.h" + +namespace { +class FederatedMetricTest : public xgboost::BaseFederatedTest {}; +} // anonymous namespace + +namespace xgboost { +namespace metric { +TEST_F(FederatedMetricTest, BinaryAUCRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, BinaryAUCColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, MultiClassAUCRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, MultiClassAUCColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, RankingAUCRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, RankingAUCColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, PRAUCRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, PRAUCColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, MultiClassPRAUCRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, MultiClassPRAUCColumnSplit) { + 
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, RankingPRAUCRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, RankingPRAUCColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, RMSERowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, RMSEColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, RMSLERowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, RMSLEColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, MAERowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, MAEColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, MAPERowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, MAPEColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, MPHERowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, MPHEColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, LogLossRowSplit) { + 
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, LogLossColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, ErrorRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, ErrorColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, PoissonNegLogLikRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, PoissonNegLogLikColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, MultiRMSERowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, MultiRMSEColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, QuantileRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, QuantileColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, MultiClassErrorRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, MultiClassErrorColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, MultiClassLogLossRowSplit) { + 
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, MultiClassLogLossColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, PrecisionRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, PrecisionColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, NDCGRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, NDCGColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, MAPRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, MAPColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, NDCGExpGainRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, NDCGExpGainColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain, + DataSplitMode::kCol); +} +} // namespace metric +} // namespace xgboost + +namespace xgboost { +namespace common { +TEST_F(FederatedMetricTest, AFTNegLogLikRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, AFTNegLogLikColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik, + DataSplitMode::kCol); +} + +TEST_F(FederatedMetricTest, 
IntervalRegressionAccuracyRowSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy, + DataSplitMode::kRow); +} + +TEST_F(FederatedMetricTest, IntervalRegressionAccuracyColumnSplit) { + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy, + DataSplitMode::kCol); +} +} // namespace common +} // namespace xgboost