More collective aggregators (#9060)
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#pragma once
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@@ -57,5 +58,72 @@ void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& fu
|
||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Find the global max of the given value across all workers.
|
||||
*
|
||||
* This only applies when the data is split row-wise (horizontally). When data is split
|
||||
* column-wise (vertically), the local value is returned.
|
||||
*
|
||||
* @tparam T The type of the value.
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param value The input for finding the global max.
|
||||
* @return The global max of the input.
|
||||
*/
|
||||
template <typename T>
|
||||
T GlobalMax(MetaInfo const& info, T value) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kMax>(&value, 1);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Find the global sum of the given values across all workers.
|
||||
*
|
||||
* This only applies when the data is split row-wise (horizontally). When data is split
|
||||
* column-wise (vertically), the original values are returned.
|
||||
*
|
||||
* @tparam T The type of the values.
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param values Pointer to the inputs to sum.
|
||||
* @param size Number of values to sum.
|
||||
*/
|
||||
template <typename T>
|
||||
void GlobalSum(MetaInfo const& info, T* values, size_t size) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(values, size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
void GlobalSum(MetaInfo const& info, Container* values) {
|
||||
GlobalSum(info, values->data(), values->size());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Find the global ratio of the given two values across all workers.
|
||||
*
|
||||
* This only applies when the data is split row-wise (horizontally). When data is split
|
||||
* column-wise (vertically), the local ratio is returned.
|
||||
*
|
||||
* @tparam T The type of the values.
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param dividend The dividend of the ratio.
|
||||
* @param divisor The divisor of the ratio.
|
||||
* @return The global ratio of the two inputs.
|
||||
*/
|
||||
template <typename T>
|
||||
T GlobalRatio(MetaInfo const& info, T dividend, T divisor) {
|
||||
std::array<T, 2> results{dividend, divisor};
|
||||
GlobalSum(info, &results);
|
||||
std::tie(dividend, divisor) = std::tuple_cat(results);
|
||||
if (divisor <= 0) {
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
} else {
|
||||
return dividend / divisor;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user