Make sure metrics work with federated learning (#9037)

This commit is contained in:
Rong Ou
2023-04-19 00:39:11 -07:00
committed by GitHub
parent ef13dd31b1
commit 42d100de18
11 changed files with 451 additions and 152 deletions

View File

@@ -0,0 +1,62 @@
/**
* Copyright 2023 by XGBoost contributors
*
* Higher-level functions built on top of the Communicator API that take care of the behavioral
* differences between row-split and column-split distributed training, and between horizontal and
* vertical federated learning.
*/
#pragma once
#include <xgboost/data.h>
#include <cstring>
#include <string>
#include <utility>
#include <vector>
#include "communicator-inl.h"
namespace xgboost {
namespace collective {
/**
 * @brief Apply the given function where the labels are.
 *
 * Normally all the workers have access to the labels, so the function is just applied locally. In
 * vertical federated learning, we assume labels are only available on worker 0, so the function is
 * applied there, with the results broadcast to other workers.
 *
 * In the vertical federated case a fixed-size status message is broadcast from worker 0 before the
 * result. If the function throws a dmlc::Error on worker 0, its text (truncated to fit the
 * 1024-byte message buffer) is sent to every worker and reported through LOG(FATAL) instead of
 * broadcasting the result. An empty status message means success.
 *
 * @tparam Function The function used to calculate the results.
 * @tparam Args Arguments to the function.
 * @param info MetaInfo about the DMatrix.
 * @param buffer The buffer storing the results. Only these raw bytes are broadcast, so the function
 *               must leave its results in this memory.
 * @param size The size of the buffer, in bytes.
 * @param function The function used to calculate the results.
 * @param args Arguments to the function.
 */
template <typename Function, typename... Args>
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function,
                     Args&&... args) {
  if (info.IsVerticalFederated()) {
    // We assume labels are only available on worker 0, so the calculation is done there and result
    // broadcast to other workers.
    //
    // The vector is value-initialized, i.e. filled with '\0', which encodes "no error".
    std::vector<char> message(1024);
    if (collective::GetRank() == 0) {
      try {
        std::forward<Function>(function)(std::forward<Args>(args)...);
      } catch (dmlc::Error const& e) {
        // Copy at most size - 1 bytes so the trailing '\0' of the buffer is always preserved.
        std::strncpy(message.data(), e.what(), message.size() - 1);
        if (message.front() == '\0') {
          // An error with an empty description would be indistinguishable from success and make
          // every worker broadcast an unfilled result buffer, so substitute a non-empty message.
          std::strncpy(message.data(), "Unknown error while applying function on worker 0",
                       message.size() - 1);
        }
      }
    }

    // Every worker learns whether worker 0 succeeded before deciding what to do next.
    collective::Broadcast(message.data(), message.size(), 0);
    if (message.front() == '\0') {
      collective::Broadcast(buffer, size, 0);
    } else {
      LOG(FATAL) << message.data();
    }
  } else {
    std::forward<Function>(function)(std::forward<Args>(args)...);
  }
}
} // namespace collective
} // namespace xgboost

View File

@@ -34,6 +34,7 @@
#include <utility> // for pair, as_const, move, swap
#include <vector> // for vector
#include "collective/aggregator.h" // for ApplyWithLabels
#include "collective/communicator-inl.h" // for Allreduce, Broadcast, GetRank, IsDistributed
#include "collective/communicator.h" // for Operation
#include "common/api_entry.h" // for XGBAPIThreadLocalEntry
@@ -859,22 +860,10 @@ class LearnerConfiguration : public Learner {
}
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
// Special handling for vertical federated learning.
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the estimation is calculated there
// and broadcast to other workers.
if (collective::GetRank() == 0) {
UsePtr(obj_)->InitEstimation(info, base_score);
collective::Broadcast(base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(), 0);
} else {
base_score->Reshape(1);
collective::Broadcast(base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(), 0);
}
} else {
UsePtr(obj_)->InitEstimation(info, base_score);
}
base_score->Reshape(1);
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(),
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
}
};
@@ -1486,24 +1475,10 @@ class LearnerImpl : public LearnerIO {
private:
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
HostDeviceVector<GradientPair>* out_gpair) {
// Special handling for vertical federated learning.
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the gradients are calculated there
// and broadcast to other workers.
if (collective::GetRank() == 0) {
obj_->GetGradient(preds, info, iteration, out_gpair);
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
0);
} else {
CHECK_EQ(info.labels.Size(), 0)
<< "In vertical federated learning, labels should only be on the first worker";
out_gpair->Resize(preds.Size());
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
0);
}
} else {
obj_->GetGradient(preds, info, iteration, out_gpair);
}
out_gpair->Resize(preds.Size());
collective::ApplyWithLabels(info, out_gpair->HostPointer(),
out_gpair->Size() * sizeof(GradientPair),
[&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
}
/*! \brief random number transformation seed. */

View File

@@ -270,7 +270,9 @@ class EvalAUC : public MetricNoCache {
}
// We use the global size to handle empty dataset.
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
if (!info.IsVerticalFederated()) {
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
}
if (meta[0] == 0) {
// Empty across all workers, which is not supported.
auc = std::numeric_limits<double>::quiet_NaN();

View File

@@ -9,6 +9,8 @@
#include <memory> // shared_ptr
#include <string>
#include "../collective/aggregator.h"
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "xgboost/metric.h"
@@ -20,7 +22,12 @@ class MetricNoCache : public Metric {
virtual double Eval(HostDeviceVector<float> const &predts, MetaInfo const &info) = 0;
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
return this->Eval(predts, p_fmat->Info());
double result{0.0};
auto const& info = p_fmat->Info();
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
result = this->Eval(predts, info);
});
return result;
}
};

View File

@@ -28,9 +28,8 @@
#include <algorithm> // for stable_sort, copy, fill_n, min, max
#include <array> // for array
#include <cmath> // for log, sqrt
#include <cstddef> // for size_t, std
#include <cstdint> // for uint32_t
#include <functional> // for less, greater
#include <limits> // for numeric_limits
#include <map> // for operator!=, _Rb_tree_const_iterator
#include <memory> // for allocator, unique_ptr, shared_ptr, __shared_...
#include <numeric> // for accumulate
@@ -39,15 +38,11 @@
#include <utility> // for pair, make_pair
#include <vector> // for vector
#include "../collective/communicator-inl.h" // for IsDistributed, Allreduce
#include "../collective/communicator.h" // for Operation
#include "../collective/aggregator.h" // for ApplyWithLabels
#include "../common/algorithm.h" // for ArgSort, Sort
#include "../common/linalg_op.h" // for cbegin, cend
#include "../common/math.h" // for CmpFirst
#include "../common/optional_weight.h" // for OptionalWeights, MakeOptionalWeights
#include "../common/ranking_utils.h" // for LambdaRankParam, NDCGCache, ParseMetricName
#include "../common/threading_utils.h" // for ParallelFor
#include "../common/transform_iterator.h" // for IndexTransformIter
#include "dmlc/common.h" // for OMPException
#include "metric_common.h" // for MetricNoCache, GPUMetric, PackedReduceResult
#include "xgboost/base.h" // for bst_float, bst_omp_uint, bst_group_t, Args
@@ -59,7 +54,6 @@
#include "xgboost/linalg.h" // for Tensor, TensorView, Range, VectorView, MakeT...
#include "xgboost/logging.h" // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ
#include "xgboost/metric.h" // for MetricReg, XGBOOST_REGISTER_METRIC, Metric
#include "xgboost/span.h" // for Span, operator!=
#include "xgboost/string_view.h" // for StringView
namespace {
@@ -385,15 +379,19 @@ class EvalRankWithCache : public Metric {
}
double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
double result{0.0};
auto const& info = p_fmat->Info();
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
if (p_cache->Param() != param_) {
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
}
CHECK(p_cache->Param() == param_);
CHECK_EQ(preds.Size(), info.labels.Size());
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
if (p_cache->Param() != param_) {
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
}
CHECK(p_cache->Param() == param_);
CHECK_EQ(preds.Size(), info.labels.Size());
return this->Eval(preds, info, p_cache);
result = this->Eval(preds, info, p_cache);
});
return result;
}
virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,