Make sure metrics work with federated learning (#9037)
This commit is contained in:
@@ -270,7 +270,9 @@ class EvalAUC : public MetricNoCache {
|
||||
}
|
||||
// We use the global size to handle empty dataset.
|
||||
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
|
||||
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
|
||||
if (!info.IsVerticalFederated()) {
|
||||
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
|
||||
}
|
||||
if (meta[0] == 0) {
|
||||
// Empty across all workers, which is not supported.
|
||||
auc = std::numeric_limits<double>::quiet_NaN();
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
#include <memory> // shared_ptr
|
||||
#include <string>
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h"
|
||||
#include "xgboost/metric.h"
|
||||
|
||||
@@ -20,7 +22,12 @@ class MetricNoCache : public Metric {
|
||||
virtual double Eval(HostDeviceVector<float> const &predts, MetaInfo const &info) = 0;
|
||||
|
||||
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
|
||||
return this->Eval(predts, p_fmat->Info());
|
||||
double result{0.0};
|
||||
auto const& info = p_fmat->Info();
|
||||
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
|
||||
result = this->Eval(predts, info);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -28,9 +28,8 @@
|
||||
#include <algorithm> // for stable_sort, copy, fill_n, min, max
|
||||
#include <array> // for array
|
||||
#include <cmath> // for log, sqrt
|
||||
#include <cstddef> // for size_t, std
|
||||
#include <cstdint> // for uint32_t
|
||||
#include <functional> // for less, greater
|
||||
#include <limits> // for numeric_limits
|
||||
#include <map> // for operator!=, _Rb_tree_const_iterator
|
||||
#include <memory> // for allocator, unique_ptr, shared_ptr, __shared_...
|
||||
#include <numeric> // for accumulate
|
||||
@@ -39,15 +38,11 @@
|
||||
#include <utility> // for pair, make_pair
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/communicator-inl.h" // for IsDistributed, Allreduce
|
||||
#include "../collective/communicator.h" // for Operation
|
||||
#include "../collective/aggregator.h" // for ApplyWithLabels
|
||||
#include "../common/algorithm.h" // for ArgSort, Sort
|
||||
#include "../common/linalg_op.h" // for cbegin, cend
|
||||
#include "../common/math.h" // for CmpFirst
|
||||
#include "../common/optional_weight.h" // for OptionalWeights, MakeOptionalWeights
|
||||
#include "../common/ranking_utils.h" // for LambdaRankParam, NDCGCache, ParseMetricName
|
||||
#include "../common/threading_utils.h" // for ParallelFor
|
||||
#include "../common/transform_iterator.h" // for IndexTransformIter
|
||||
#include "dmlc/common.h" // for OMPException
|
||||
#include "metric_common.h" // for MetricNoCache, GPUMetric, PackedReduceResult
|
||||
#include "xgboost/base.h" // for bst_float, bst_omp_uint, bst_group_t, Args
|
||||
@@ -59,7 +54,6 @@
|
||||
#include "xgboost/linalg.h" // for Tensor, TensorView, Range, VectorView, MakeT...
|
||||
#include "xgboost/logging.h" // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ
|
||||
#include "xgboost/metric.h" // for MetricReg, XGBOOST_REGISTER_METRIC, Metric
|
||||
#include "xgboost/span.h" // for Span, operator!=
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace {
|
||||
@@ -385,15 +379,19 @@ class EvalRankWithCache : public Metric {
|
||||
}
|
||||
|
||||
double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
|
||||
double result{0.0};
|
||||
auto const& info = p_fmat->Info();
|
||||
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
|
||||
if (p_cache->Param() != param_) {
|
||||
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
|
||||
}
|
||||
CHECK(p_cache->Param() == param_);
|
||||
CHECK_EQ(preds.Size(), info.labels.Size());
|
||||
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
|
||||
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
|
||||
if (p_cache->Param() != param_) {
|
||||
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
|
||||
}
|
||||
CHECK(p_cache->Param() == param_);
|
||||
CHECK_EQ(preds.Size(), info.labels.Size());
|
||||
|
||||
return this->Eval(preds, info, p_cache);
|
||||
result = this->Eval(preds, info, p_cache);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
|
||||
|
||||
Reference in New Issue
Block a user