Make sure metrics work with federated learning (#9037)

This commit is contained in:
Rong Ou
2023-04-19 00:39:11 -07:00
committed by GitHub
parent ef13dd31b1
commit 42d100de18
11 changed files with 451 additions and 152 deletions

View File

@@ -270,7 +270,9 @@ class EvalAUC : public MetricNoCache {
}
// We use the global size to handle empty dataset.
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
if (!info.IsVerticalFederated()) {
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
}
if (meta[0] == 0) {
// Empty across all workers, which is not supported.
auc = std::numeric_limits<double>::quiet_NaN();

View File

@@ -9,6 +9,8 @@
#include <memory> // shared_ptr
#include <string>
#include "../collective/aggregator.h"
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "xgboost/metric.h"
@@ -20,7 +22,12 @@ class MetricNoCache : public Metric {
virtual double Eval(HostDeviceVector<float> const &predts, MetaInfo const &info) = 0;
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
return this->Eval(predts, p_fmat->Info());
double result{0.0};
auto const& info = p_fmat->Info();
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
result = this->Eval(predts, info);
});
return result;
}
};

View File

@@ -28,9 +28,8 @@
#include <algorithm> // for stable_sort, copy, fill_n, min, max
#include <array> // for array
#include <cmath> // for log, sqrt
#include <cstddef> // for size_t, std
#include <cstdint> // for uint32_t
#include <functional> // for less, greater
#include <limits> // for numeric_limits
#include <map> // for operator!=, _Rb_tree_const_iterator
#include <memory> // for allocator, unique_ptr, shared_ptr, __shared_...
#include <numeric> // for accumulate
@@ -39,15 +38,11 @@
#include <utility> // for pair, make_pair
#include <vector> // for vector
#include "../collective/communicator-inl.h" // for IsDistributed, Allreduce
#include "../collective/communicator.h" // for Operation
#include "../collective/aggregator.h" // for ApplyWithLabels
#include "../common/algorithm.h" // for ArgSort, Sort
#include "../common/linalg_op.h" // for cbegin, cend
#include "../common/math.h" // for CmpFirst
#include "../common/optional_weight.h" // for OptionalWeights, MakeOptionalWeights
#include "../common/ranking_utils.h" // for LambdaRankParam, NDCGCache, ParseMetricName
#include "../common/threading_utils.h" // for ParallelFor
#include "../common/transform_iterator.h" // for IndexTransformIter
#include "dmlc/common.h" // for OMPException
#include "metric_common.h" // for MetricNoCache, GPUMetric, PackedReduceResult
#include "xgboost/base.h" // for bst_float, bst_omp_uint, bst_group_t, Args
@@ -59,7 +54,6 @@
#include "xgboost/linalg.h" // for Tensor, TensorView, Range, VectorView, MakeT...
#include "xgboost/logging.h" // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ
#include "xgboost/metric.h" // for MetricReg, XGBOOST_REGISTER_METRIC, Metric
#include "xgboost/span.h" // for Span, operator!=
#include "xgboost/string_view.h" // for StringView
namespace {
@@ -385,15 +379,19 @@ class EvalRankWithCache : public Metric {
}
double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
double result{0.0};
auto const& info = p_fmat->Info();
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
if (p_cache->Param() != param_) {
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
}
CHECK(p_cache->Param() == param_);
CHECK_EQ(preds.Size(), info.labels.Size());
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
if (p_cache->Param() != param_) {
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
}
CHECK(p_cache->Param() == param_);
CHECK_EQ(preds.Size(), info.labels.Size());
return this->Eval(preds, info, p_cache);
result = this->Eval(preds, info, p_cache);
});
return result;
}
virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,