Make sure metrics work with federated learning (#9037)

Rong Ou 2023-04-19 00:39:11 -07:00, committed by GitHub
parent ef13dd31b1
commit 42d100de18
11 changed files with 451 additions and 152 deletions

View File

@@ -0,0 +1,62 @@
/**
* Copyright 2023 by XGBoost contributors
*
 * Higher level functions built on top of the Communicator API, taking care of behavioral differences
* between row-split vs column-split distributed training, and horizontal vs vertical federated
* learning.
*/
#pragma once
#include <xgboost/data.h>

#include <cstring>  // for strlen, strncpy
#include <string>
#include <utility>
#include <vector>

#include "communicator-inl.h"
namespace xgboost {
namespace collective {
/**
* @brief Apply the given function where the labels are.
*
* Normally all the workers have access to the labels, so the function is just applied locally. In
* vertical federated learning, we assume labels are only available on worker 0, so the function is
* applied there, with the results broadcast to other workers.
*
* @tparam Function The function used to calculate the results.
* @tparam Args Arguments to the function.
* @param info MetaInfo about the DMatrix.
* @param buffer The buffer storing the results.
* @param size The size of the buffer.
* @param function The function used to calculate the results.
* @param args Arguments to the function.
*/
template <typename Function, typename... Args>
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function,
                     Args&&... args) {
  if (info.IsVerticalFederated()) {
    // We assume labels are only available on worker 0, so the calculation is done there and the
    // result is broadcast to the other workers.
    std::vector<char> message(1024);
    if (collective::GetRank() == 0) {
      try {
        std::forward<Function>(function)(std::forward<Args>(args)...);
      } catch (dmlc::Error& e) {
        // Capture the error message (truncated to the buffer size) so it can be rethrown
        // consistently on every worker.
        strncpy(&message[0], e.what(), message.size());
        message.back() = '\0';
      }
    }
    collective::Broadcast(&message[0], message.size(), 0);
    if (strlen(&message[0]) == 0) {
      collective::Broadcast(buffer, size, 0);
    } else {
      LOG(FATAL) << &message[0];
    }
  } else {
    std::forward<Function>(function)(std::forward<Args>(args)...);
  }
}
} // namespace collective
} // namespace xgboost
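A minimal usage sketch of the new helper (illustrative, not part of this commit; the function name and buffer are hypothetical, and an initialised communicator plus a populated MetaInfo are assumed). Under vertical federated learning only rank 0 runs the lambda, and the buffer is then broadcast so every worker ends up with identical results:

#include <cstddef>
#include <vector>

#include "collective/aggregator.h"
#include "xgboost/data.h"

namespace xgboost {
void ExampleComputeWithLabels(MetaInfo const& info) {
  std::vector<double> results(4, 0.0);  // result buffer, same size on every worker
  collective::ApplyWithLabels(info, results.data(), results.size() * sizeof(double), [&] {
    // In vertical federated mode only rank 0 executes this body, since it is
    // the worker holding the labels; compute the label-dependent values here.
    results.assign(results.size(), static_cast<double>(info.labels.Size()));
  });
  // Every worker now holds the same `results`; if the lambda threw on rank 0,
  // the error message is broadcast instead and all workers LOG(FATAL).
}
}  // namespace xgboost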

View File

@@ -34,6 +34,7 @@
 #include <utility>                        // for pair, as_const, move, swap
 #include <vector>                         // for vector
 
+#include "collective/aggregator.h"        // for ApplyWithLabels
 #include "collective/communicator-inl.h"  // for Allreduce, Broadcast, GetRank, IsDistributed
 #include "collective/communicator.h"      // for Operation
 #include "common/api_entry.h"             // for XGBAPIThreadLocalEntry
@@ -859,22 +860,10 @@ class LearnerConfiguration : public Learner {
   }
 
   void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
-    // Special handling for vertical federated learning.
-    if (info.IsVerticalFederated()) {
-      // We assume labels are only available on worker 0, so the estimation is calculated there
-      // and broadcast to other workers.
-      if (collective::GetRank() == 0) {
-        UsePtr(obj_)->InitEstimation(info, base_score);
-        collective::Broadcast(base_score->Data()->HostPointer(),
-                              sizeof(bst_float) * base_score->Size(), 0);
-      } else {
-        base_score->Reshape(1);
-        collective::Broadcast(base_score->Data()->HostPointer(),
-                              sizeof(bst_float) * base_score->Size(), 0);
-      }
-    } else {
-      UsePtr(obj_)->InitEstimation(info, base_score);
-    }
+    base_score->Reshape(1);
+    collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
+                                sizeof(bst_float) * base_score->Size(),
+                                [&] { UsePtr(obj_)->InitEstimation(info, base_score); });
   }
 };
@@ -1486,24 +1475,10 @@ class LearnerImpl : public LearnerIO {
  private:
   void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
                    HostDeviceVector<GradientPair>* out_gpair) {
-    // Special handling for vertical federated learning.
-    if (info.IsVerticalFederated()) {
-      // We assume labels are only available on worker 0, so the gradients are calculated there
-      // and broadcast to other workers.
-      if (collective::GetRank() == 0) {
-        obj_->GetGradient(preds, info, iteration, out_gpair);
-        collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
-                              0);
-      } else {
-        CHECK_EQ(info.labels.Size(), 0)
-            << "In vertical federated learning, labels should only be on the first worker";
-        out_gpair->Resize(preds.Size());
-        collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
-                              0);
-      }
-    } else {
-      obj_->GetGradient(preds, info, iteration, out_gpair);
-    }
+    out_gpair->Resize(preds.Size());
+    collective::ApplyWithLabels(info, out_gpair->HostPointer(),
+                                out_gpair->Size() * sizeof(GradientPair),
+                                [&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
   }
 
   /*! \brief random number transformation seed. */

View File

@@ -270,7 +270,9 @@ class EvalAUC : public MetricNoCache {
     }
     // We use the global size to handle empty dataset.
     std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
-    collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
+    if (!info.IsVerticalFederated()) {
+      collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
+    }
     if (meta[0] == 0) {
       // Empty across all workers, which is not supported.
       auc = std::numeric_limits<double>::quiet_NaN();
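The new guard matters because EvalAUC inherits from MetricNoCache, whose Evaluate (changed in the next file) now routes vertical federated evaluation through ApplyWithLabels; only rank 0 reaches this point in that mode, so an unguarded collective call would leave the other workers waiting and deadlock the job. A hedged sketch of the pattern, with illustrative names:

#include <array>
#include <cstddef>

#include "../collective/communicator-inl.h"
#include "xgboost/data.h"

namespace xgboost {
// Illustrative helper: take the global maximum of a local size, skipping the
// collective when only rank 0 is executing (vertical federated learning).
std::size_t GlobalMaxSize(MetaInfo const& info, std::size_t local_size) {
  std::array<std::size_t, 1> sizes{local_size};
  if (!info.IsVerticalFederated()) {
    collective::Allreduce<collective::Operation::kMax>(sizes.data(), sizes.size());
  }
  return sizes[0];
}
}  // namespace xgboost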

View File

@@ -9,6 +9,8 @@
 #include <memory>  // shared_ptr
 #include <string>
 
+#include "../collective/aggregator.h"
+#include "../collective/communicator-inl.h"
 #include "../common/common.h"
 #include "xgboost/metric.h"
@@ -20,7 +22,12 @@ class MetricNoCache : public Metric {
   virtual double Eval(HostDeviceVector<float> const &predts, MetaInfo const &info) = 0;
 
   double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
-    return this->Eval(predts, p_fmat->Info());
+    double result{0.0};
+    auto const& info = p_fmat->Info();
+    collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
+      result = this->Eval(predts, info);
+    });
+    return result;
   }
 };

View File

@@ -28,9 +28,8 @@
 #include <algorithm>                      // for stable_sort, copy, fill_n, min, max
 #include <array>                          // for array
 #include <cmath>                          // for log, sqrt
-#include <cstddef>                        // for size_t, std
-#include <cstdint>                        // for uint32_t
 #include <functional>                     // for less, greater
+#include <limits>                         // for numeric_limits
 #include <map>                            // for operator!=, _Rb_tree_const_iterator
 #include <memory>                         // for allocator, unique_ptr, shared_ptr, __shared_...
 #include <numeric>                        // for accumulate
@@ -39,15 +38,11 @@
 #include <utility>                        // for pair, make_pair
 #include <vector>                         // for vector
 
-#include "../collective/communicator-inl.h"  // for IsDistributed, Allreduce
-#include "../collective/communicator.h"      // for Operation
+#include "../collective/aggregator.h"        // for ApplyWithLabels
 #include "../common/algorithm.h"             // for ArgSort, Sort
 #include "../common/linalg_op.h"             // for cbegin, cend
 #include "../common/math.h"                  // for CmpFirst
 #include "../common/optional_weight.h"       // for OptionalWeights, MakeOptionalWeights
-#include "../common/ranking_utils.h"         // for LambdaRankParam, NDCGCache, ParseMetricName
-#include "../common/threading_utils.h"       // for ParallelFor
-#include "../common/transform_iterator.h"    // for IndexTransformIter
 #include "dmlc/common.h"                     // for OMPException
 #include "metric_common.h"                   // for MetricNoCache, GPUMetric, PackedReduceResult
 #include "xgboost/base.h"                    // for bst_float, bst_omp_uint, bst_group_t, Args
@@ -59,7 +54,6 @@
 #include "xgboost/linalg.h"                  // for Tensor, TensorView, Range, VectorView, MakeT...
 #include "xgboost/logging.h"                 // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ
 #include "xgboost/metric.h"                  // for MetricReg, XGBOOST_REGISTER_METRIC, Metric
-#include "xgboost/span.h"                    // for Span, operator!=
 #include "xgboost/string_view.h"             // for StringView
 
 namespace {
@@ -385,15 +379,19 @@ class EvalRankWithCache : public Metric {
   }
 
   double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
+    double result{0.0};
     auto const& info = p_fmat->Info();
-    auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
-    if (p_cache->Param() != param_) {
-      p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
-    }
-    CHECK(p_cache->Param() == param_);
-    CHECK_EQ(preds.Size(), info.labels.Size());
-
-    return this->Eval(preds, info, p_cache);
+    collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
+      auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
+      if (p_cache->Param() != param_) {
+        p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
+      }
+      CHECK(p_cache->Param() == param_);
+      CHECK_EQ(preds.Size(), info.labels.Size());
+
+      result = this->Eval(preds, info, p_cache);
+    });
+    return result;
   }
 
   virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,

View File

@@ -189,7 +189,9 @@ double GetMultiMetricEval(xgboost::Metric* metric,
   info.weights_.HostVector() = weights;
   info.group_ptr_ = groups;
   info.data_split_mode = data_split_mode;
+  if (info.IsVerticalFederated() && xgboost::collective::GetRank() != 0) {
+    info.labels.Reshape(0);
+  }
   return metric->Evaluate(preds, p_fmat);
 }
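The reshape above is how the test helper mimics vertical federated label placement: every worker keeps its feature columns, but only rank 0 keeps the labels. The same idea as a standalone sketch (hypothetical helper; the include path is illustrative, and a federated communicator is assumed to be running on each worker):

#include "../../src/collective/communicator-inl.h"  // illustrative path
#include "xgboost/data.h"

namespace xgboost {
// Hypothetical test helper: mark the data as column-split and drop labels on
// all workers except rank 0.
void SimulateVerticalFederatedLabels(MetaInfo* info) {
  info->data_split_mode = DataSplitMode::kCol;  // vertical split
  if (collective::GetRank() != 0) {
    info->labels.Reshape(0);  // non-zero ranks hold no labels
  }
}
}  // namespace xgboost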

View File

@@ -2,109 +2,13 @@
  * Copyright (c) by Contributors 2020
  */
 #include <gtest/gtest.h>
-#include <cmath>
+#include "test_survival_metric.h"
 #include "xgboost/metric.h"
-#include "../helpers.h"
-#include "../../../src/common/survival_util.h"
 
 /** Tests for Survival metrics that should run both on CPU and GPU **/
 namespace xgboost {
 namespace common {
-namespace {
-inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
-  auto ctx = CreateEmptyGenericParam(device);
-  std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
-  metric->Configure(Args{});
-  HostDeviceVector<float> predts;
-  auto p_fmat = EmptyDMatrix();
-  MetaInfo& info = p_fmat->Info();
-  auto& h_predts = predts.HostVector();
-  SimpleLCG lcg;
-  SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
-  size_t n_samples = 2048;
-  h_predts.resize(n_samples);
-  for (size_t i = 0; i < n_samples; ++i) {
-    h_predts[i] = dist(&lcg);
-  }
-  auto& h_upper = info.labels_upper_bound_.HostVector();
-  auto& h_lower = info.labels_lower_bound_.HostVector();
-  h_lower.resize(n_samples);
-  h_upper.resize(n_samples);
-  for (size_t i = 0; i < n_samples; ++i) {
-    h_lower[i] = 1;
-    h_upper[i] = 10;
-  }
-  auto result = metric->Evaluate(predts, p_fmat);
-  for (size_t i = 0; i < 8; ++i) {
-    ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
-  }
-}
-
-void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) {
-  auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
-  /**
-   * Test aggregate output from the AFT metric over a small test data set.
-   * This is unlike AFTLoss.* tests, which verify metric values over individual data points.
-   **/
-  auto p_fmat = EmptyDMatrix();
-  MetaInfo& info = p_fmat->Info();
-  info.num_row_ = 4;
-  info.labels_lower_bound_.HostVector() = {100.0f, 0.0f, 60.0f, 16.0f};
-  info.labels_upper_bound_.HostVector() =
-      {100.0f, 20.0f, std::numeric_limits<bst_float>::infinity(), 200.0f};
-  info.weights_.HostVector() = std::vector<bst_float>();
-  info.data_split_mode = data_split_mode;
-  HostDeviceVector<bst_float> preds(4, std::log(64));
-  struct TestCase {
-    std::string dist_type;
-    bst_float reference_value;
-  };
-  for (const auto& test_case : std::vector<TestCase>{
-           {"normal", 2.1508f}, {"logistic", 2.1804f}, {"extreme", 2.0706f}}) {
-    std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &ctx));
-    metric->Configure({{"aft_loss_distribution", test_case.dist_type},
-                       {"aft_loss_distribution_scale", "1.0"}});
-    EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
-  }
-}
-
-void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) {
-  auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
-  auto p_fmat = EmptyDMatrix();
-  MetaInfo& info = p_fmat->Info();
-  info.num_row_ = 4;
-  info.labels_lower_bound_.HostVector() = {20.0f, 0.0f, 60.0f, 16.0f};
-  info.labels_upper_bound_.HostVector() = {80.0f, 20.0f, 80.0f, 200.0f};
-  info.weights_.HostVector() = std::vector<bst_float>();
-  info.data_split_mode = data_split_mode;
-  HostDeviceVector<bst_float> preds(4, std::log(60.0f));
-  std::unique_ptr<Metric> metric(Metric::Create("interval-regression-accuracy", &ctx));
-  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);
-  info.labels_lower_bound_.HostVector()[2] = 70.0f;
-  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
-  info.labels_upper_bound_.HostVector()[2] = std::numeric_limits<bst_float>::infinity();
-  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
-  info.labels_upper_bound_.HostVector()[3] = std::numeric_limits<bst_float>::infinity();
-  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
-  info.labels_lower_bound_.HostVector()[0] = 70.0f;
-  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);
-  CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
-}
-}  // anonymous namespace
-
 TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) { VerifyAFTNegLogLik(); }
 
 TEST_F(DeclareUnifiedDistributedTest(MetricTest), AFTNegLogLikRowSplit) {
@@ -140,6 +44,5 @@ TEST(AFTNegLogLikMetric, DeclareUnifiedTest(Configuration)) {
   CheckDeterministicMetricElementWise(StringView{"aft-nloglik"}, GPUIDX);
 }
-
 }  // namespace common
 }  // namespace xgboost

View File

@@ -0,0 +1,107 @@
/**
* Copyright 2020-2023 by XGBoost Contributors
*/
#pragma once
#include <gtest/gtest.h>
#include <cmath>
#include "../../../src/common/survival_util.h"
#include "../helpers.h"
#include "xgboost/metric.h"
namespace xgboost {
namespace common {
inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
  auto ctx = CreateEmptyGenericParam(device);
  std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
  metric->Configure(Args{});

  HostDeviceVector<float> predts;
  auto p_fmat = EmptyDMatrix();
  MetaInfo& info = p_fmat->Info();
  auto& h_predts = predts.HostVector();

  SimpleLCG lcg;
  SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};

  size_t n_samples = 2048;
  h_predts.resize(n_samples);
  for (size_t i = 0; i < n_samples; ++i) {
    h_predts[i] = dist(&lcg);
  }

  auto& h_upper = info.labels_upper_bound_.HostVector();
  auto& h_lower = info.labels_lower_bound_.HostVector();
  h_lower.resize(n_samples);
  h_upper.resize(n_samples);
  for (size_t i = 0; i < n_samples; ++i) {
    h_lower[i] = 1;
    h_upper[i] = 10;
  }

  auto result = metric->Evaluate(predts, p_fmat);
  for (size_t i = 0; i < 8; ++i) {
    ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
  }
}

inline void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) {
  auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
  /**
   * Test aggregate output from the AFT metric over a small test data set.
   * This is unlike AFTLoss.* tests, which verify metric values over individual data points.
   **/
  auto p_fmat = EmptyDMatrix();
  MetaInfo& info = p_fmat->Info();
  info.num_row_ = 4;
  info.labels_lower_bound_.HostVector() = {100.0f, 0.0f, 60.0f, 16.0f};
  info.labels_upper_bound_.HostVector() =
      {100.0f, 20.0f, std::numeric_limits<bst_float>::infinity(), 200.0f};
  info.weights_.HostVector() = std::vector<bst_float>();
  info.data_split_mode = data_split_mode;

  HostDeviceVector<bst_float> preds(4, std::log(64));

  struct TestCase {
    std::string dist_type;
    bst_float reference_value;
  };
  for (const auto& test_case : std::vector<TestCase>{
           {"normal", 2.1508f}, {"logistic", 2.1804f}, {"extreme", 2.0706f}}) {
    std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &ctx));
    metric->Configure({{"aft_loss_distribution", test_case.dist_type},
                       {"aft_loss_distribution_scale", "1.0"}});
    EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
  }
}

inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) {
  auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
  auto p_fmat = EmptyDMatrix();
  MetaInfo& info = p_fmat->Info();
  info.num_row_ = 4;
  info.labels_lower_bound_.HostVector() = {20.0f, 0.0f, 60.0f, 16.0f};
  info.labels_upper_bound_.HostVector() = {80.0f, 20.0f, 80.0f, 200.0f};
  info.weights_.HostVector() = std::vector<bst_float>();
  info.data_split_mode = data_split_mode;

  HostDeviceVector<bst_float> preds(4, std::log(60.0f));

  std::unique_ptr<Metric> metric(Metric::Create("interval-regression-accuracy", &ctx));
  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);
  info.labels_lower_bound_.HostVector()[2] = 70.0f;
  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
  info.labels_upper_bound_.HostVector()[2] = std::numeric_limits<bst_float>::infinity();
  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
  info.labels_upper_bound_.HostVector()[3] = std::numeric_limits<bst_float>::infinity();
  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
  info.labels_lower_bound_.HostVector()[0] = 70.0f;
  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);

  CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
}
} // namespace common
} // namespace xgboost

View File

@@ -65,7 +65,7 @@ class BaseFederatedTest : public ::testing::Test {
   void TearDown() override { server_.reset(nullptr); }
 
-  static int const kWorldSize{3};
+  static int constexpr kWorldSize{3};
   std::unique_ptr<ServerForTest> server_;
 };

View File

@@ -70,7 +70,7 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e
 class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
   std::unique_ptr<ServerForTest> server_;
-  static int const kWorldSize{3};
+  static int constexpr kWorldSize{3};
 
  protected:
   void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }

View File

@@ -0,0 +1,243 @@
/*!
* Copyright 2023 XGBoost contributors
*/
#include <gtest/gtest.h>
#include "../metric/test_auc.h"
#include "../metric/test_elementwise_metric.h"
#include "../metric/test_multiclass_metric.h"
#include "../metric/test_rank_metric.h"
#include "../metric/test_survival_metric.h"
#include "helpers.h"
namespace {
class FederatedMetricTest : public xgboost::BaseFederatedTest {};
} // anonymous namespace
namespace xgboost {
namespace metric {
TEST_F(FederatedMetricTest, BinaryAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, BinaryAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MultiClassAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MultiClassAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, RankingAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, RankingAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, PRAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, PRAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MultiClassPRAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MultiClassPRAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, RankingPRAUCRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, RankingPRAUCColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, RMSERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, RMSEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, RMSLERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, RMSLEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MAERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MAEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MAPERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MAPEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MPHERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MPHEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, LogLossRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, LogLossColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, ErrorRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, ErrorColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, PoissonNegLogLikRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, PoissonNegLogLikColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MultiRMSERowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MultiRMSEColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, QuantileRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, QuantileColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MultiClassErrorRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MultiClassErrorColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MultiClassLogLossRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MultiClassLogLossColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, PrecisionRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, PrecisionColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, NDCGRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, NDCGColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, MAPRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, MAPColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, NDCGExpGainRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, NDCGExpGainColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain,
DataSplitMode::kCol);
}
} // namespace metric
} // namespace xgboost
namespace xgboost {
namespace common {
TEST_F(FederatedMetricTest, AFTNegLogLikRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, AFTNegLogLikColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik,
DataSplitMode::kCol);
}
TEST_F(FederatedMetricTest, IntervalRegressionAccuracyRowSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy,
DataSplitMode::kRow);
}
TEST_F(FederatedMetricTest, IntervalRegressionAccuracyColumnSplit) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy,
DataSplitMode::kCol);
}
} // namespace common
} // namespace xgboost