Require context in aggregators. (#10075)

This commit is contained in:
Jiaming Yuan
2024-02-28 03:12:42 +08:00
committed by GitHub
parent 761845f594
commit 5ac233280e
23 changed files with 190 additions and 144 deletions

View File

@@ -1,22 +1,21 @@
/**
* Copyright 2023 by XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*
* Higher level functions built on top the Communicator API, taking care of behavioral differences
* between row-split vs column-split distributed training, and horizontal vs vertical federated
* learning.
*/
#pragma once
#include <xgboost/data.h>
#include <limits>
#include <string>
#include <utility>
#include <vector>
#include "communicator-inl.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/data.h" // for MetaINfo
namespace xgboost {
namespace collective {
namespace xgboost::collective {
/**
* @brief Apply the given function where the labels are.
@@ -31,15 +30,16 @@ namespace collective {
* @param size The size of the buffer.
* @param function The function used to calculate the results.
*/
template <typename Function>
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function) {
template <typename FN>
void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size,
FN&& function) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)();
std::forward<FN>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
@@ -52,7 +52,7 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
LOG(FATAL) << &message[0];
}
} else {
std::forward<Function>(function)();
std::forward<FN>(function)();
}
}
@@ -70,7 +70,8 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
* @param function The function used to calculate the results.
*/
template <typename T, typename Function>
void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
Function&& function) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
@@ -114,7 +115,9 @@ void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function
* @return The global max of the input.
*/
template <typename T>
T GlobalMax(MetaInfo const& info, T value) {
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const*,
MetaInfo const& info,
T value) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kMax>(&value, 1);
}
@@ -132,16 +135,18 @@ T GlobalMax(MetaInfo const& info, T value) {
* @param values Pointer to the inputs to sum.
* @param size Number of values to sum.
*/
template <typename T>
void GlobalSum(MetaInfo const& info, T* values, size_t size) {
template <typename T, std::int32_t kDim>
[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info,
linalg::TensorView<T, kDim> values) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(values, size);
collective::Allreduce<collective::Operation::kSum>(values.Values().data(), values.Size());
}
return Success();
}
template <typename Container>
void GlobalSum(MetaInfo const& info, Container* values) {
GlobalSum(info, values->data(), values->size());
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) {
return GlobalSum(ctx, info, values->data(), values->size());
}
/**
@@ -157,9 +162,10 @@ void GlobalSum(MetaInfo const& info, Container* values) {
* @return The global ratio of the two inputs.
*/
template <typename T>
T GlobalRatio(MetaInfo const& info, T dividend, T divisor) {
T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
std::array<T, 2> results{dividend, divisor};
GlobalSum(info, &results);
auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
collective::SafeColl(rc);
std::tie(dividend, divisor) = std::tuple_cat(results);
if (divisor <= 0) {
return std::numeric_limits<T>::quiet_NaN();
@@ -167,6 +173,4 @@ T GlobalRatio(MetaInfo const& info, T dividend, T divisor) {
return dividend / divisor;
}
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective