Make sure metrics work with federated learning (#9037)
This commit is contained in:
parent
ef13dd31b1
commit
42d100de18
62
src/collective/aggregator.h
Normal file
62
src/collective/aggregator.h
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023 by XGBoost contributors
|
||||||
|
*
|
||||||
|
 * Higher level functions built on top of the Communicator API, taking care of behavioral differences
|
||||||
|
* between row-split vs column-split distributed training, and horizontal vs vertical federated
|
||||||
|
* learning.
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <xgboost/data.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "communicator-inl.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Apply the given function where the labels are.
|
||||||
|
*
|
||||||
|
* Normally all the workers have access to the labels, so the function is just applied locally. In
|
||||||
|
* vertical federated learning, we assume labels are only available on worker 0, so the function is
|
||||||
|
* applied there, with the results broadcast to other workers.
|
||||||
|
*
|
||||||
|
* @tparam Function The function used to calculate the results.
|
||||||
|
* @tparam Args Arguments to the function.
|
||||||
|
* @param info MetaInfo about the DMatrix.
|
||||||
|
* @param buffer The buffer storing the results.
|
||||||
|
* @param size The size of the buffer.
|
||||||
|
* @param function The function used to calculate the results.
|
||||||
|
* @param args Arguments to the function.
|
||||||
|
*/
|
||||||
|
template <typename Function, typename... Args>
|
||||||
|
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function,
|
||||||
|
Args&&... args) {
|
||||||
|
if (info.IsVerticalFederated()) {
|
||||||
|
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||||
|
// broadcast to other workers.
|
||||||
|
std::vector<char> message(1024);
|
||||||
|
if (collective::GetRank() == 0) {
|
||||||
|
try {
|
||||||
|
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||||
|
} catch (dmlc::Error& e) {
|
||||||
|
strncpy(&message[0], e.what(), message.size());
|
||||||
|
message.back() = '\0';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
collective::Broadcast(&message[0], message.size(), 0);
|
||||||
|
if (strlen(&message[0]) == 0) {
|
||||||
|
collective::Broadcast(buffer, size, 0);
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << &message[0];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
@ -34,6 +34,7 @@
|
|||||||
#include <utility> // for pair, as_const, move, swap
|
#include <utility> // for pair, as_const, move, swap
|
||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "collective/aggregator.h" // for ApplyWithLabels
|
||||||
#include "collective/communicator-inl.h" // for Allreduce, Broadcast, GetRank, IsDistributed
|
#include "collective/communicator-inl.h" // for Allreduce, Broadcast, GetRank, IsDistributed
|
||||||
#include "collective/communicator.h" // for Operation
|
#include "collective/communicator.h" // for Operation
|
||||||
#include "common/api_entry.h" // for XGBAPIThreadLocalEntry
|
#include "common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||||
@ -859,22 +860,10 @@ class LearnerConfiguration : public Learner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
|
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
|
||||||
// Special handling for vertical federated learning.
|
base_score->Reshape(1);
|
||||||
if (info.IsVerticalFederated()) {
|
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
|
||||||
// We assume labels are only available on worker 0, so the estimation is calculated there
|
sizeof(bst_float) * base_score->Size(),
|
||||||
// and broadcast to other workers.
|
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
|
||||||
if (collective::GetRank() == 0) {
|
|
||||||
UsePtr(obj_)->InitEstimation(info, base_score);
|
|
||||||
collective::Broadcast(base_score->Data()->HostPointer(),
|
|
||||||
sizeof(bst_float) * base_score->Size(), 0);
|
|
||||||
} else {
|
|
||||||
base_score->Reshape(1);
|
|
||||||
collective::Broadcast(base_score->Data()->HostPointer(),
|
|
||||||
sizeof(bst_float) * base_score->Size(), 0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
UsePtr(obj_)->InitEstimation(info, base_score);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1486,24 +1475,10 @@ class LearnerImpl : public LearnerIO {
|
|||||||
private:
|
private:
|
||||||
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
|
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
|
||||||
HostDeviceVector<GradientPair>* out_gpair) {
|
HostDeviceVector<GradientPair>* out_gpair) {
|
||||||
// Special handling for vertical federated learning.
|
out_gpair->Resize(preds.Size());
|
||||||
if (info.IsVerticalFederated()) {
|
collective::ApplyWithLabels(info, out_gpair->HostPointer(),
|
||||||
// We assume labels are only available on worker 0, so the gradients are calculated there
|
out_gpair->Size() * sizeof(GradientPair),
|
||||||
// and broadcast to other workers.
|
[&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
|
||||||
if (collective::GetRank() == 0) {
|
|
||||||
obj_->GetGradient(preds, info, iteration, out_gpair);
|
|
||||||
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
|
|
||||||
0);
|
|
||||||
} else {
|
|
||||||
CHECK_EQ(info.labels.Size(), 0)
|
|
||||||
<< "In vertical federated learning, labels should only be on the first worker";
|
|
||||||
out_gpair->Resize(preds.Size());
|
|
||||||
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
|
|
||||||
0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
obj_->GetGradient(preds, info, iteration, out_gpair);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief random number transformation seed. */
|
/*! \brief random number transformation seed. */
|
||||||
|
|||||||
@ -270,7 +270,9 @@ class EvalAUC : public MetricNoCache {
|
|||||||
}
|
}
|
||||||
// We use the global size to handle empty dataset.
|
// We use the global size to handle empty dataset.
|
||||||
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
|
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
|
||||||
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
|
if (!info.IsVerticalFederated()) {
|
||||||
|
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
|
||||||
|
}
|
||||||
if (meta[0] == 0) {
|
if (meta[0] == 0) {
|
||||||
// Empty across all workers, which is not supported.
|
// Empty across all workers, which is not supported.
|
||||||
auc = std::numeric_limits<double>::quiet_NaN();
|
auc = std::numeric_limits<double>::quiet_NaN();
|
||||||
|
|||||||
@ -9,6 +9,8 @@
|
|||||||
#include <memory> // shared_ptr
|
#include <memory> // shared_ptr
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "../collective/aggregator.h"
|
||||||
|
#include "../collective/communicator-inl.h"
|
||||||
#include "../common/common.h"
|
#include "../common/common.h"
|
||||||
#include "xgboost/metric.h"
|
#include "xgboost/metric.h"
|
||||||
|
|
||||||
@ -20,7 +22,12 @@ class MetricNoCache : public Metric {
|
|||||||
virtual double Eval(HostDeviceVector<float> const &predts, MetaInfo const &info) = 0;
|
virtual double Eval(HostDeviceVector<float> const &predts, MetaInfo const &info) = 0;
|
||||||
|
|
||||||
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
|
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
|
||||||
return this->Eval(predts, p_fmat->Info());
|
double result{0.0};
|
||||||
|
auto const& info = p_fmat->Info();
|
||||||
|
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
|
||||||
|
result = this->Eval(predts, info);
|
||||||
|
});
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -28,9 +28,8 @@
|
|||||||
#include <algorithm> // for stable_sort, copy, fill_n, min, max
|
#include <algorithm> // for stable_sort, copy, fill_n, min, max
|
||||||
#include <array> // for array
|
#include <array> // for array
|
||||||
#include <cmath> // for log, sqrt
|
#include <cmath> // for log, sqrt
|
||||||
#include <cstddef> // for size_t, std
|
|
||||||
#include <cstdint> // for uint32_t
|
|
||||||
#include <functional> // for less, greater
|
#include <functional> // for less, greater
|
||||||
|
#include <limits> // for numeric_limits
|
||||||
#include <map> // for operator!=, _Rb_tree_const_iterator
|
#include <map> // for operator!=, _Rb_tree_const_iterator
|
||||||
#include <memory> // for allocator, unique_ptr, shared_ptr, __shared_...
|
#include <memory> // for allocator, unique_ptr, shared_ptr, __shared_...
|
||||||
#include <numeric> // for accumulate
|
#include <numeric> // for accumulate
|
||||||
@ -39,15 +38,11 @@
|
|||||||
#include <utility> // for pair, make_pair
|
#include <utility> // for pair, make_pair
|
||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../collective/communicator-inl.h" // for IsDistributed, Allreduce
|
#include "../collective/aggregator.h" // for ApplyWithLabels
|
||||||
#include "../collective/communicator.h" // for Operation
|
|
||||||
#include "../common/algorithm.h" // for ArgSort, Sort
|
#include "../common/algorithm.h" // for ArgSort, Sort
|
||||||
#include "../common/linalg_op.h" // for cbegin, cend
|
#include "../common/linalg_op.h" // for cbegin, cend
|
||||||
#include "../common/math.h" // for CmpFirst
|
#include "../common/math.h" // for CmpFirst
|
||||||
#include "../common/optional_weight.h" // for OptionalWeights, MakeOptionalWeights
|
#include "../common/optional_weight.h" // for OptionalWeights, MakeOptionalWeights
|
||||||
#include "../common/ranking_utils.h" // for LambdaRankParam, NDCGCache, ParseMetricName
|
|
||||||
#include "../common/threading_utils.h" // for ParallelFor
|
|
||||||
#include "../common/transform_iterator.h" // for IndexTransformIter
|
|
||||||
#include "dmlc/common.h" // for OMPException
|
#include "dmlc/common.h" // for OMPException
|
||||||
#include "metric_common.h" // for MetricNoCache, GPUMetric, PackedReduceResult
|
#include "metric_common.h" // for MetricNoCache, GPUMetric, PackedReduceResult
|
||||||
#include "xgboost/base.h" // for bst_float, bst_omp_uint, bst_group_t, Args
|
#include "xgboost/base.h" // for bst_float, bst_omp_uint, bst_group_t, Args
|
||||||
@ -59,7 +54,6 @@
|
|||||||
#include "xgboost/linalg.h" // for Tensor, TensorView, Range, VectorView, MakeT...
|
#include "xgboost/linalg.h" // for Tensor, TensorView, Range, VectorView, MakeT...
|
||||||
#include "xgboost/logging.h" // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ
|
#include "xgboost/logging.h" // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ
|
||||||
#include "xgboost/metric.h" // for MetricReg, XGBOOST_REGISTER_METRIC, Metric
|
#include "xgboost/metric.h" // for MetricReg, XGBOOST_REGISTER_METRIC, Metric
|
||||||
#include "xgboost/span.h" // for Span, operator!=
|
|
||||||
#include "xgboost/string_view.h" // for StringView
|
#include "xgboost/string_view.h" // for StringView
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -385,15 +379,19 @@ class EvalRankWithCache : public Metric {
|
|||||||
}
|
}
|
||||||
|
|
||||||
double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
|
double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
|
||||||
|
double result{0.0};
|
||||||
auto const& info = p_fmat->Info();
|
auto const& info = p_fmat->Info();
|
||||||
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
|
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
|
||||||
if (p_cache->Param() != param_) {
|
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
|
||||||
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
|
if (p_cache->Param() != param_) {
|
||||||
}
|
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
|
||||||
CHECK(p_cache->Param() == param_);
|
}
|
||||||
CHECK_EQ(preds.Size(), info.labels.Size());
|
CHECK(p_cache->Param() == param_);
|
||||||
|
CHECK_EQ(preds.Size(), info.labels.Size());
|
||||||
|
|
||||||
return this->Eval(preds, info, p_cache);
|
result = this->Eval(preds, info, p_cache);
|
||||||
|
});
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
|
virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
|
||||||
|
|||||||
@ -189,7 +189,9 @@ double GetMultiMetricEval(xgboost::Metric* metric,
|
|||||||
info.weights_.HostVector() = weights;
|
info.weights_.HostVector() = weights;
|
||||||
info.group_ptr_ = groups;
|
info.group_ptr_ = groups;
|
||||||
info.data_split_mode = data_split_mode;
|
info.data_split_mode = data_split_mode;
|
||||||
|
if (info.IsVerticalFederated() && xgboost::collective::GetRank() != 0) {
|
||||||
|
info.labels.Reshape(0);
|
||||||
|
}
|
||||||
return metric->Evaluate(preds, p_fmat);
|
return metric->Evaluate(preds, p_fmat);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -2,109 +2,13 @@
|
|||||||
* Copyright (c) by Contributors 2020
|
* Copyright (c) by Contributors 2020
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <cmath>
|
#include "test_survival_metric.h"
|
||||||
#include "xgboost/metric.h"
|
#include "xgboost/metric.h"
|
||||||
#include "../helpers.h"
|
|
||||||
#include "../../../src/common/survival_util.h"
|
|
||||||
|
|
||||||
/** Tests for Survival metrics that should run both on CPU and GPU **/
|
/** Tests for Survival metrics that should run both on CPU and GPU **/
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
namespace {
|
|
||||||
inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
|
|
||||||
auto ctx = CreateEmptyGenericParam(device);
|
|
||||||
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
|
|
||||||
metric->Configure(Args{});
|
|
||||||
|
|
||||||
HostDeviceVector<float> predts;
|
|
||||||
auto p_fmat = EmptyDMatrix();
|
|
||||||
MetaInfo& info = p_fmat->Info();
|
|
||||||
auto &h_predts = predts.HostVector();
|
|
||||||
|
|
||||||
SimpleLCG lcg;
|
|
||||||
SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
|
|
||||||
|
|
||||||
size_t n_samples = 2048;
|
|
||||||
h_predts.resize(n_samples);
|
|
||||||
|
|
||||||
for (size_t i = 0; i < n_samples; ++i) {
|
|
||||||
h_predts[i] = dist(&lcg);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto &h_upper = info.labels_upper_bound_.HostVector();
|
|
||||||
auto &h_lower = info.labels_lower_bound_.HostVector();
|
|
||||||
h_lower.resize(n_samples);
|
|
||||||
h_upper.resize(n_samples);
|
|
||||||
for (size_t i = 0; i < n_samples; ++i) {
|
|
||||||
h_lower[i] = 1;
|
|
||||||
h_upper[i] = 10;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto result = metric->Evaluate(predts, p_fmat);
|
|
||||||
for (size_t i = 0; i < 8; ++i) {
|
|
||||||
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) {
|
|
||||||
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Test aggregate output from the AFT metric over a small test data set.
|
|
||||||
* This is unlike AFTLoss.* tests, which verify metric values over individual data points.
|
|
||||||
**/
|
|
||||||
auto p_fmat = EmptyDMatrix();
|
|
||||||
MetaInfo& info = p_fmat->Info();
|
|
||||||
info.num_row_ = 4;
|
|
||||||
info.labels_lower_bound_.HostVector()
|
|
||||||
= { 100.0f, 0.0f, 60.0f, 16.0f };
|
|
||||||
info.labels_upper_bound_.HostVector()
|
|
||||||
= { 100.0f, 20.0f, std::numeric_limits<bst_float>::infinity(), 200.0f };
|
|
||||||
info.weights_.HostVector() = std::vector<bst_float>();
|
|
||||||
info.data_split_mode = data_split_mode;
|
|
||||||
HostDeviceVector<bst_float> preds(4, std::log(64));
|
|
||||||
|
|
||||||
struct TestCase {
|
|
||||||
std::string dist_type;
|
|
||||||
bst_float reference_value;
|
|
||||||
};
|
|
||||||
for (const auto& test_case : std::vector<TestCase>{ {"normal", 2.1508f}, {"logistic", 2.1804f},
|
|
||||||
{"extreme", 2.0706f} }) {
|
|
||||||
std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &ctx));
|
|
||||||
metric->Configure({ {"aft_loss_distribution", test_case.dist_type},
|
|
||||||
{"aft_loss_distribution_scale", "1.0"} });
|
|
||||||
EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) {
|
|
||||||
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
|
|
||||||
|
|
||||||
auto p_fmat = EmptyDMatrix();
|
|
||||||
MetaInfo& info = p_fmat->Info();
|
|
||||||
info.num_row_ = 4;
|
|
||||||
info.labels_lower_bound_.HostVector() = { 20.0f, 0.0f, 60.0f, 16.0f };
|
|
||||||
info.labels_upper_bound_.HostVector() = { 80.0f, 20.0f, 80.0f, 200.0f };
|
|
||||||
info.weights_.HostVector() = std::vector<bst_float>();
|
|
||||||
info.data_split_mode = data_split_mode;
|
|
||||||
HostDeviceVector<bst_float> preds(4, std::log(60.0f));
|
|
||||||
|
|
||||||
std::unique_ptr<Metric> metric(Metric::Create("interval-regression-accuracy", &ctx));
|
|
||||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);
|
|
||||||
info.labels_lower_bound_.HostVector()[2] = 70.0f;
|
|
||||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
|
||||||
info.labels_upper_bound_.HostVector()[2] = std::numeric_limits<bst_float>::infinity();
|
|
||||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
|
||||||
info.labels_upper_bound_.HostVector()[3] = std::numeric_limits<bst_float>::infinity();
|
|
||||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
|
||||||
info.labels_lower_bound_.HostVector()[0] = 70.0f;
|
|
||||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);
|
|
||||||
|
|
||||||
CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
|
|
||||||
}
|
|
||||||
} // anonymous namespace
|
|
||||||
|
|
||||||
TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) { VerifyAFTNegLogLik(); }
|
TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) { VerifyAFTNegLogLik(); }
|
||||||
|
|
||||||
TEST_F(DeclareUnifiedDistributedTest(MetricTest), AFTNegLogLikRowSplit) {
|
TEST_F(DeclareUnifiedDistributedTest(MetricTest), AFTNegLogLikRowSplit) {
|
||||||
@ -140,6 +44,5 @@ TEST(AFTNegLogLikMetric, DeclareUnifiedTest(Configuration)) {
|
|||||||
|
|
||||||
CheckDeterministicMetricElementWise(StringView{"aft-nloglik"}, GPUIDX);
|
CheckDeterministicMetricElementWise(StringView{"aft-nloglik"}, GPUIDX);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
107
tests/cpp/metric/test_survival_metric.h
Normal file
107
tests/cpp/metric/test_survival_metric.h
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020-2023 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "../../../src/common/survival_util.h"
|
||||||
|
#include "../helpers.h"
|
||||||
|
#include "xgboost/metric.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
/**
 * @brief Check that repeated evaluations of a metric on identical inputs give identical results.
 *
 * Fills 2048 predictions from a default-constructed SimpleLCG through a [0, 1) uniform
 * distribution, sets every lower label bound to 1 and every upper label bound to 10, evaluates
 * the metric once and then asserts that 8 further evaluations return exactly the same value.
 *
 * @param name   Metric name handed to Metric::Create.
 * @param device Device ordinal used to create the context.
 */
inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
  auto ctx = CreateEmptyGenericParam(device);
  std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
  metric->Configure(Args{});

  size_t const n_samples = 2048;
  size_t const n_repeats = 8;

  auto p_fmat = EmptyDMatrix();
  MetaInfo& info = p_fmat->Info();

  // Predictions: one draw per sample, generated in index order.
  HostDeviceVector<float> predts;
  auto& h_predts = predts.HostVector();
  h_predts.resize(n_samples);
  SimpleLCG lcg;
  SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
  for (auto& value : h_predts) {
    value = dist(&lcg);
  }

  // Label bounds: the same [1, 10] interval for every sample.
  auto& h_upper = info.labels_upper_bound_.HostVector();
  auto& h_lower = info.labels_lower_bound_.HostVector();
  h_lower.resize(n_samples);
  h_upper.resize(n_samples);
  for (auto& bound : h_lower) {
    bound = 1;
  }
  for (auto& bound : h_upper) {
    bound = 10;
  }

  // The first evaluation is the reference; every repeat has to match it exactly.
  auto const expected = metric->Evaluate(predts, p_fmat);
  for (size_t repeat = 0; repeat < n_repeats; ++repeat) {
    ASSERT_EQ(metric->Evaluate(predts, p_fmat), expected);
  }
}
|
||||||
|
|
||||||
|
/**
 * @brief Test aggregate output from the AFT metric over a small test data set.
 *
 * This is unlike AFTLoss.* tests, which verify metric values over individual data points.
 * The "aft-nloglik" metric is evaluated once per loss distribution ("normal", "logistic",
 * "extreme") with scale 1.0 and compared against a reference value with tolerance 1e-4.
 *
 * @param data_split_mode Split mode stored in the MetaInfo of the test matrix.
 */
inline void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) {
  auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);

  // Four rows with no weights; the third row has an infinite upper bound.
  auto p_fmat = EmptyDMatrix();
  MetaInfo& info = p_fmat->Info();
  info.num_row_ = 4;
  info.labels_lower_bound_.HostVector() = {100.0f, 0.0f, 60.0f, 16.0f};
  info.labels_upper_bound_.HostVector() = {100.0f, 20.0f,
                                           std::numeric_limits<bst_float>::infinity(), 200.0f};
  info.weights_.HostVector() = std::vector<bst_float>();
  info.data_split_mode = data_split_mode;

  // Every row gets the same prediction, log(64).
  HostDeviceVector<bst_float> preds(4, std::log(64));

  struct TestCase {
    std::string dist_type;
    bst_float reference_value;
  };
  std::vector<TestCase> const cases{
      {"normal", 2.1508f}, {"logistic", 2.1804f}, {"extreme", 2.0706f}};

  for (auto const& test_case : cases) {
    // A fresh metric per distribution so configurations do not carry over.
    std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &ctx));
    metric->Configure({{"aft_loss_distribution", test_case.dist_type},
                       {"aft_loss_distribution_scale", "1.0"}});
    EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
  }
}
|
||||||
|
|
||||||
|
/**
 * @brief Check "interval-regression-accuracy" on a four-row data set while the label bounds are
 *        modified step by step, then run the determinism check for the same metric.
 *
 * All predictions are log(60). The expected values below are consistent with the metric counting
 * rows whose interval contains 60 -- NOTE(review): inferred from the expectations, confirm
 * against the metric implementation.
 *
 * @param data_split_mode Split mode stored in the MetaInfo of the test matrix.
 */
inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) {
  auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
  auto const inf = std::numeric_limits<bst_float>::infinity();

  auto p_fmat = EmptyDMatrix();
  MetaInfo& info = p_fmat->Info();
  info.num_row_ = 4;
  info.labels_lower_bound_.HostVector() = {20.0f, 0.0f, 60.0f, 16.0f};
  info.labels_upper_bound_.HostVector() = {80.0f, 20.0f, 80.0f, 200.0f};
  info.weights_.HostVector() = std::vector<bst_float>();
  info.data_split_mode = data_split_mode;
  HostDeviceVector<bst_float> preds(4, std::log(60.0f));

  std::unique_ptr<Metric> metric(Metric::Create("interval-regression-accuracy", &ctx));

  // Initial bounds.
  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);

  // HostVector() is re-fetched before every modification on purpose; no reference to the host
  // data is kept across Evaluate().

  // Raise the lower bound of row 2.
  info.labels_lower_bound_.HostVector()[2] = 70.0f;
  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);

  // Make the upper bound of row 2 infinite.
  info.labels_upper_bound_.HostVector()[2] = inf;
  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);

  // Make the upper bound of row 3 infinite.
  info.labels_upper_bound_.HostVector()[3] = inf;
  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);

  // Raise the lower bound of row 0.
  info.labels_lower_bound_.HostVector()[0] = 70.0f;
  EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);

  CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
}
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
@ -65,7 +65,7 @@ class BaseFederatedTest : public ::testing::Test {
|
|||||||
|
|
||||||
void TearDown() override { server_.reset(nullptr); }
|
void TearDown() override { server_.reset(nullptr); }
|
||||||
|
|
||||||
static int const kWorldSize{3};
|
static int constexpr kWorldSize{3};
|
||||||
std::unique_ptr<ServerForTest> server_;
|
std::unique_ptr<ServerForTest> server_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -70,7 +70,7 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e
|
|||||||
|
|
||||||
class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
|
class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
|
||||||
std::unique_ptr<ServerForTest> server_;
|
std::unique_ptr<ServerForTest> server_;
|
||||||
static int const kWorldSize{3};
|
static int constexpr kWorldSize{3};
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
|
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
|
||||||
|
|||||||
243
tests/cpp/plugin/test_federated_metrics.cc
Normal file
243
tests/cpp/plugin/test_federated_metrics.cc
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2023 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "../metric/test_auc.h"
|
||||||
|
#include "../metric/test_elementwise_metric.h"
|
||||||
|
#include "../metric/test_multiclass_metric.h"
|
||||||
|
#include "../metric/test_rank_metric.h"
|
||||||
|
#include "../metric/test_survival_metric.h"
|
||||||
|
#include "helpers.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class FederatedMetricTest : public xgboost::BaseFederatedTest {};
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace metric {
|
||||||
|
TEST_F(FederatedMetricTest, BinaryAUCRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, BinaryAUCColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MultiClassAUCRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MultiClassAUCColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, RankingAUCRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, RankingAUCColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, PRAUCRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, PRAUCColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MultiClassPRAUCRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MultiClassPRAUCColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, RankingPRAUCRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, RankingPRAUCColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, RMSERowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, RMSEColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, RMSLERowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, RMSLEColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MAERowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MAEColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MAPERowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MAPEColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MPHERowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MPHEColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, LogLossRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, LogLossColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, ErrorRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, ErrorColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, PoissonNegLogLikRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, PoissonNegLogLikColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MultiRMSERowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MultiRMSEColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, QuantileRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, QuantileColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MultiClassErrorRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MultiClassErrorColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MultiClassLogLossRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MultiClassLogLossColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, PrecisionRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, PrecisionColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, NDCGRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, NDCGColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MAPRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, MAPColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, NDCGExpGainRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, NDCGExpGainColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
} // namespace metric
|
||||||
|
} // namespace xgboost
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
TEST_F(FederatedMetricTest, AFTNegLogLikRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, AFTNegLogLikColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, IntervalRegressionAccuracyRowSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy,
|
||||||
|
DataSplitMode::kRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FederatedMetricTest, IntervalRegressionAccuracyColumnSplit) {
|
||||||
|
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy,
|
||||||
|
DataSplitMode::kCol);
|
||||||
|
}
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
Loading…
x
Reference in New Issue
Block a user