Make sure metrics work with federated learning (#9037)
This commit is contained in:
parent
ef13dd31b1
commit
42d100de18
62
src/collective/aggregator.h
Normal file
62
src/collective/aggregator.h
Normal file
@ -0,0 +1,62 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*
|
||||
* Higher level functions built on top the Communicator API, taking care of behavioral differences
|
||||
* between row-split vs column-split distributed training, and horizontal vs vertical federated
|
||||
* learning.
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator-inl.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Apply the given function where the labels are.
|
||||
*
|
||||
* Normally all the workers have access to the labels, so the function is just applied locally. In
|
||||
* vertical federated learning, we assume labels are only available on worker 0, so the function is
|
||||
* applied there, with the results broadcast to other workers.
|
||||
*
|
||||
* @tparam Function The function used to calculate the results.
|
||||
* @tparam Args Arguments to the function.
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param buffer The buffer storing the results.
|
||||
* @param size The size of the buffer.
|
||||
* @param function The function used to calculate the results.
|
||||
* @param args Arguments to the function.
|
||||
*/
|
||||
template <typename Function, typename... Args>
|
||||
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function,
|
||||
Args&&... args) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||
// broadcast to other workers.
|
||||
std::vector<char> message(1024);
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||
} catch (dmlc::Error& e) {
|
||||
strncpy(&message[0], e.what(), message.size());
|
||||
message.back() = '\0';
|
||||
}
|
||||
}
|
||||
collective::Broadcast(&message[0], message.size(), 0);
|
||||
if (strlen(&message[0]) == 0) {
|
||||
collective::Broadcast(buffer, size, 0);
|
||||
} else {
|
||||
LOG(FATAL) << &message[0];
|
||||
}
|
||||
} else {
|
||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -34,6 +34,7 @@
|
||||
#include <utility> // for pair, as_const, move, swap
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "collective/aggregator.h" // for ApplyWithLabels
|
||||
#include "collective/communicator-inl.h" // for Allreduce, Broadcast, GetRank, IsDistributed
|
||||
#include "collective/communicator.h" // for Operation
|
||||
#include "common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
@ -859,22 +860,10 @@ class LearnerConfiguration : public Learner {
|
||||
}
|
||||
|
||||
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
|
||||
// Special handling for vertical federated learning.
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the estimation is calculated there
|
||||
// and broadcast to other workers.
|
||||
if (collective::GetRank() == 0) {
|
||||
UsePtr(obj_)->InitEstimation(info, base_score);
|
||||
collective::Broadcast(base_score->Data()->HostPointer(),
|
||||
sizeof(bst_float) * base_score->Size(), 0);
|
||||
} else {
|
||||
base_score->Reshape(1);
|
||||
collective::Broadcast(base_score->Data()->HostPointer(),
|
||||
sizeof(bst_float) * base_score->Size(), 0);
|
||||
}
|
||||
} else {
|
||||
UsePtr(obj_)->InitEstimation(info, base_score);
|
||||
}
|
||||
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
|
||||
sizeof(bst_float) * base_score->Size(),
|
||||
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
|
||||
}
|
||||
};
|
||||
|
||||
@ -1486,24 +1475,10 @@ class LearnerImpl : public LearnerIO {
|
||||
private:
|
||||
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
|
||||
HostDeviceVector<GradientPair>* out_gpair) {
|
||||
// Special handling for vertical federated learning.
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the gradients are calculated there
|
||||
// and broadcast to other workers.
|
||||
if (collective::GetRank() == 0) {
|
||||
obj_->GetGradient(preds, info, iteration, out_gpair);
|
||||
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
|
||||
0);
|
||||
} else {
|
||||
CHECK_EQ(info.labels.Size(), 0)
|
||||
<< "In vertical federated learning, labels should only be on the first worker";
|
||||
out_gpair->Resize(preds.Size());
|
||||
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
|
||||
0);
|
||||
}
|
||||
} else {
|
||||
obj_->GetGradient(preds, info, iteration, out_gpair);
|
||||
}
|
||||
collective::ApplyWithLabels(info, out_gpair->HostPointer(),
|
||||
out_gpair->Size() * sizeof(GradientPair),
|
||||
[&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
|
||||
}
|
||||
|
||||
/*! \brief random number transformation seed. */
|
||||
|
||||
@ -270,7 +270,9 @@ class EvalAUC : public MetricNoCache {
|
||||
}
|
||||
// We use the global size to handle empty dataset.
|
||||
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
|
||||
if (!info.IsVerticalFederated()) {
|
||||
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
|
||||
}
|
||||
if (meta[0] == 0) {
|
||||
// Empty across all workers, which is not supported.
|
||||
auc = std::numeric_limits<double>::quiet_NaN();
|
||||
|
||||
@ -9,6 +9,8 @@
|
||||
#include <memory> // shared_ptr
|
||||
#include <string>
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h"
|
||||
#include "xgboost/metric.h"
|
||||
|
||||
@ -20,7 +22,12 @@ class MetricNoCache : public Metric {
|
||||
virtual double Eval(HostDeviceVector<float> const &predts, MetaInfo const &info) = 0;
|
||||
|
||||
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
|
||||
return this->Eval(predts, p_fmat->Info());
|
||||
double result{0.0};
|
||||
auto const& info = p_fmat->Info();
|
||||
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
|
||||
result = this->Eval(predts, info);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -28,9 +28,8 @@
|
||||
#include <algorithm> // for stable_sort, copy, fill_n, min, max
|
||||
#include <array> // for array
|
||||
#include <cmath> // for log, sqrt
|
||||
#include <cstddef> // for size_t, std
|
||||
#include <cstdint> // for uint32_t
|
||||
#include <functional> // for less, greater
|
||||
#include <limits> // for numeric_limits
|
||||
#include <map> // for operator!=, _Rb_tree_const_iterator
|
||||
#include <memory> // for allocator, unique_ptr, shared_ptr, __shared_...
|
||||
#include <numeric> // for accumulate
|
||||
@ -39,15 +38,11 @@
|
||||
#include <utility> // for pair, make_pair
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/communicator-inl.h" // for IsDistributed, Allreduce
|
||||
#include "../collective/communicator.h" // for Operation
|
||||
#include "../collective/aggregator.h" // for ApplyWithLabels
|
||||
#include "../common/algorithm.h" // for ArgSort, Sort
|
||||
#include "../common/linalg_op.h" // for cbegin, cend
|
||||
#include "../common/math.h" // for CmpFirst
|
||||
#include "../common/optional_weight.h" // for OptionalWeights, MakeOptionalWeights
|
||||
#include "../common/ranking_utils.h" // for LambdaRankParam, NDCGCache, ParseMetricName
|
||||
#include "../common/threading_utils.h" // for ParallelFor
|
||||
#include "../common/transform_iterator.h" // for IndexTransformIter
|
||||
#include "dmlc/common.h" // for OMPException
|
||||
#include "metric_common.h" // for MetricNoCache, GPUMetric, PackedReduceResult
|
||||
#include "xgboost/base.h" // for bst_float, bst_omp_uint, bst_group_t, Args
|
||||
@ -59,7 +54,6 @@
|
||||
#include "xgboost/linalg.h" // for Tensor, TensorView, Range, VectorView, MakeT...
|
||||
#include "xgboost/logging.h" // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ
|
||||
#include "xgboost/metric.h" // for MetricReg, XGBOOST_REGISTER_METRIC, Metric
|
||||
#include "xgboost/span.h" // for Span, operator!=
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace {
|
||||
@ -385,7 +379,9 @@ class EvalRankWithCache : public Metric {
|
||||
}
|
||||
|
||||
double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
|
||||
double result{0.0};
|
||||
auto const& info = p_fmat->Info();
|
||||
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
|
||||
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
|
||||
if (p_cache->Param() != param_) {
|
||||
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
|
||||
@ -393,7 +389,9 @@ class EvalRankWithCache : public Metric {
|
||||
CHECK(p_cache->Param() == param_);
|
||||
CHECK_EQ(preds.Size(), info.labels.Size());
|
||||
|
||||
return this->Eval(preds, info, p_cache);
|
||||
result = this->Eval(preds, info, p_cache);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
|
||||
|
||||
@ -189,7 +189,9 @@ double GetMultiMetricEval(xgboost::Metric* metric,
|
||||
info.weights_.HostVector() = weights;
|
||||
info.group_ptr_ = groups;
|
||||
info.data_split_mode = data_split_mode;
|
||||
|
||||
if (info.IsVerticalFederated() && xgboost::collective::GetRank() != 0) {
|
||||
info.labels.Reshape(0);
|
||||
}
|
||||
return metric->Evaluate(preds, p_fmat);
|
||||
}
|
||||
|
||||
|
||||
@ -2,109 +2,13 @@
|
||||
* Copyright (c) by Contributors 2020
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <cmath>
|
||||
#include "test_survival_metric.h"
|
||||
#include "xgboost/metric.h"
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/common/survival_util.h"
|
||||
|
||||
/** Tests for Survival metrics that should run both on CPU and GPU **/
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace {
|
||||
inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
|
||||
auto ctx = CreateEmptyGenericParam(device);
|
||||
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
|
||||
metric->Configure(Args{});
|
||||
|
||||
HostDeviceVector<float> predts;
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
auto &h_predts = predts.HostVector();
|
||||
|
||||
SimpleLCG lcg;
|
||||
SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
|
||||
|
||||
size_t n_samples = 2048;
|
||||
h_predts.resize(n_samples);
|
||||
|
||||
for (size_t i = 0; i < n_samples; ++i) {
|
||||
h_predts[i] = dist(&lcg);
|
||||
}
|
||||
|
||||
auto &h_upper = info.labels_upper_bound_.HostVector();
|
||||
auto &h_lower = info.labels_lower_bound_.HostVector();
|
||||
h_lower.resize(n_samples);
|
||||
h_upper.resize(n_samples);
|
||||
for (size_t i = 0; i < n_samples; ++i) {
|
||||
h_lower[i] = 1;
|
||||
h_upper[i] = 10;
|
||||
}
|
||||
|
||||
auto result = metric->Evaluate(predts, p_fmat);
|
||||
for (size_t i = 0; i < 8; ++i) {
|
||||
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
|
||||
}
|
||||
}
|
||||
|
||||
void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) {
|
||||
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
|
||||
/**
|
||||
* Test aggregate output from the AFT metric over a small test data set.
|
||||
* This is unlike AFTLoss.* tests, which verify metric values over individual data points.
|
||||
**/
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
info.num_row_ = 4;
|
||||
info.labels_lower_bound_.HostVector()
|
||||
= { 100.0f, 0.0f, 60.0f, 16.0f };
|
||||
info.labels_upper_bound_.HostVector()
|
||||
= { 100.0f, 20.0f, std::numeric_limits<bst_float>::infinity(), 200.0f };
|
||||
info.weights_.HostVector() = std::vector<bst_float>();
|
||||
info.data_split_mode = data_split_mode;
|
||||
HostDeviceVector<bst_float> preds(4, std::log(64));
|
||||
|
||||
struct TestCase {
|
||||
std::string dist_type;
|
||||
bst_float reference_value;
|
||||
};
|
||||
for (const auto& test_case : std::vector<TestCase>{ {"normal", 2.1508f}, {"logistic", 2.1804f},
|
||||
{"extreme", 2.0706f} }) {
|
||||
std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &ctx));
|
||||
metric->Configure({ {"aft_loss_distribution", test_case.dist_type},
|
||||
{"aft_loss_distribution_scale", "1.0"} });
|
||||
EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) {
|
||||
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
info.num_row_ = 4;
|
||||
info.labels_lower_bound_.HostVector() = { 20.0f, 0.0f, 60.0f, 16.0f };
|
||||
info.labels_upper_bound_.HostVector() = { 80.0f, 20.0f, 80.0f, 200.0f };
|
||||
info.weights_.HostVector() = std::vector<bst_float>();
|
||||
info.data_split_mode = data_split_mode;
|
||||
HostDeviceVector<bst_float> preds(4, std::log(60.0f));
|
||||
|
||||
std::unique_ptr<Metric> metric(Metric::Create("interval-regression-accuracy", &ctx));
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);
|
||||
info.labels_lower_bound_.HostVector()[2] = 70.0f;
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
||||
info.labels_upper_bound_.HostVector()[2] = std::numeric_limits<bst_float>::infinity();
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
||||
info.labels_upper_bound_.HostVector()[3] = std::numeric_limits<bst_float>::infinity();
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
||||
info.labels_lower_bound_.HostVector()[0] = 70.0f;
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);
|
||||
|
||||
CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) { VerifyAFTNegLogLik(); }
|
||||
|
||||
TEST_F(DeclareUnifiedDistributedTest(MetricTest), AFTNegLogLikRowSplit) {
|
||||
@ -140,6 +44,5 @@ TEST(AFTNegLogLikMetric, DeclareUnifiedTest(Configuration)) {
|
||||
|
||||
CheckDeterministicMetricElementWise(StringView{"aft-nloglik"}, GPUIDX);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
107
tests/cpp/metric/test_survival_metric.h
Normal file
107
tests/cpp/metric/test_survival_metric.h
Normal file
@ -0,0 +1,107 @@
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "../../../src/common/survival_util.h"
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/metric.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
|
||||
auto ctx = CreateEmptyGenericParam(device);
|
||||
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
|
||||
metric->Configure(Args{});
|
||||
|
||||
HostDeviceVector<float> predts;
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
auto &h_predts = predts.HostVector();
|
||||
|
||||
SimpleLCG lcg;
|
||||
SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
|
||||
|
||||
size_t n_samples = 2048;
|
||||
h_predts.resize(n_samples);
|
||||
|
||||
for (size_t i = 0; i < n_samples; ++i) {
|
||||
h_predts[i] = dist(&lcg);
|
||||
}
|
||||
|
||||
auto &h_upper = info.labels_upper_bound_.HostVector();
|
||||
auto &h_lower = info.labels_lower_bound_.HostVector();
|
||||
h_lower.resize(n_samples);
|
||||
h_upper.resize(n_samples);
|
||||
for (size_t i = 0; i < n_samples; ++i) {
|
||||
h_lower[i] = 1;
|
||||
h_upper[i] = 10;
|
||||
}
|
||||
|
||||
auto result = metric->Evaluate(predts, p_fmat);
|
||||
for (size_t i = 0; i < 8; ++i) {
|
||||
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
|
||||
}
|
||||
}
|
||||
|
||||
inline void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) {
|
||||
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
|
||||
/**
|
||||
* Test aggregate output from the AFT metric over a small test data set.
|
||||
* This is unlike AFTLoss.* tests, which verify metric values over individual data points.
|
||||
**/
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
info.num_row_ = 4;
|
||||
info.labels_lower_bound_.HostVector()
|
||||
= { 100.0f, 0.0f, 60.0f, 16.0f };
|
||||
info.labels_upper_bound_.HostVector()
|
||||
= { 100.0f, 20.0f, std::numeric_limits<bst_float>::infinity(), 200.0f };
|
||||
info.weights_.HostVector() = std::vector<bst_float>();
|
||||
info.data_split_mode = data_split_mode;
|
||||
HostDeviceVector<bst_float> preds(4, std::log(64));
|
||||
|
||||
struct TestCase {
|
||||
std::string dist_type;
|
||||
bst_float reference_value;
|
||||
};
|
||||
for (const auto& test_case : std::vector<TestCase>{ {"normal", 2.1508f}, {"logistic", 2.1804f},
|
||||
{"extreme", 2.0706f} }) {
|
||||
std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &ctx));
|
||||
metric->Configure({ {"aft_loss_distribution", test_case.dist_type},
|
||||
{"aft_loss_distribution_scale", "1.0"} });
|
||||
EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) {
|
||||
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
info.num_row_ = 4;
|
||||
info.labels_lower_bound_.HostVector() = { 20.0f, 0.0f, 60.0f, 16.0f };
|
||||
info.labels_upper_bound_.HostVector() = { 80.0f, 20.0f, 80.0f, 200.0f };
|
||||
info.weights_.HostVector() = std::vector<bst_float>();
|
||||
info.data_split_mode = data_split_mode;
|
||||
HostDeviceVector<bst_float> preds(4, std::log(60.0f));
|
||||
|
||||
std::unique_ptr<Metric> metric(Metric::Create("interval-regression-accuracy", &ctx));
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);
|
||||
info.labels_lower_bound_.HostVector()[2] = 70.0f;
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
||||
info.labels_upper_bound_.HostVector()[2] = std::numeric_limits<bst_float>::infinity();
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
||||
info.labels_upper_bound_.HostVector()[3] = std::numeric_limits<bst_float>::infinity();
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
||||
info.labels_lower_bound_.HostVector()[0] = 70.0f;
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);
|
||||
|
||||
CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -65,7 +65,7 @@ class BaseFederatedTest : public ::testing::Test {
|
||||
|
||||
void TearDown() override { server_.reset(nullptr); }
|
||||
|
||||
static int const kWorldSize{3};
|
||||
static int constexpr kWorldSize{3};
|
||||
std::unique_ptr<ServerForTest> server_;
|
||||
};
|
||||
|
||||
|
||||
@ -70,7 +70,7 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e
|
||||
|
||||
class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
|
||||
std::unique_ptr<ServerForTest> server_;
|
||||
static int const kWorldSize{3};
|
||||
static int constexpr kWorldSize{3};
|
||||
|
||||
protected:
|
||||
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
|
||||
|
||||
243
tests/cpp/plugin/test_federated_metrics.cc
Normal file
243
tests/cpp/plugin/test_federated_metrics.cc
Normal file
@ -0,0 +1,243 @@
|
||||
/*!
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../metric/test_auc.h"
|
||||
#include "../metric/test_elementwise_metric.h"
|
||||
#include "../metric/test_multiclass_metric.h"
|
||||
#include "../metric/test_rank_metric.h"
|
||||
#include "../metric/test_survival_metric.h"
|
||||
#include "helpers.h"
|
||||
|
||||
namespace {
|
||||
class FederatedMetricTest : public xgboost::BaseFederatedTest {};
|
||||
} // anonymous namespace
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
TEST_F(FederatedMetricTest, BinaryAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, BinaryAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RankingAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RankingAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PRAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PRAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassPRAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassPRAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RankingPRAUCRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RankingPRAUCColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RMSERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RMSEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RMSLERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, RMSLEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAPERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAPEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MPHERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MPHEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, LogLossRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, LogLossColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, ErrorRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, ErrorColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PoissonNegLogLikRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PoissonNegLogLikColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiRMSERowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiRMSEColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, QuantileRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, QuantileColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassErrorRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassErrorColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassLogLossRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MultiClassLogLossColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PrecisionRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, PrecisionColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, NDCGRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, NDCGColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAPRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, MAPColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, NDCGExpGainRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, NDCGExpGainColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
TEST_F(FederatedMetricTest, AFTNegLogLikRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, AFTNegLogLikColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, IntervalRegressionAccuracyRowSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy,
|
||||
DataSplitMode::kRow);
|
||||
}
|
||||
|
||||
TEST_F(FederatedMetricTest, IntervalRegressionAccuracyColumnSplit) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy,
|
||||
DataSplitMode::kCol);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
Loading…
x
Reference in New Issue
Block a user