Support column split in gpu hist updater (#9384)

This commit is contained in:
Rong Ou
2023-08-31 03:09:35 -07:00
committed by GitHub
parent ccfc90e4c6
commit 9bab06cbca
10 changed files with 187 additions and 28 deletions

View File

@@ -418,7 +418,8 @@ void GPUHistEvaluator::EvaluateSplits(
// Reduce to get the best candidate from all workers.
dh::LaunchN(out_splits.size(), [world_size, all_candidates, out_splits] __device__(size_t i) {
for (auto rank = 0; rank < world_size; rank++) {
out_splits[i] = all_candidates[i];
for (auto rank = 1; rank < world_size; rank++) {
out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
}
});

View File

@@ -8,6 +8,7 @@
#include <cstdint> // uint32_t
#include <limits>
#include "../../collective/aggregator.h"
#include "../../common/deterministic.cuh"
#include "../../common/device_helpers.cuh"
#include "../../data/ellpack_page.cuh"
@@ -52,7 +53,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
*
* to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
*/
GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair) {
GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair, MetaInfo const& info) {
using GradientSumT = GradientPairPrecise;
using T = typename GradientSumT::ValueT;
dh::XGBCachingDeviceAllocator<char> alloc;
@@ -64,11 +65,11 @@ GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair) {
// Treat pair as array of 4 primitive types to allreduce
using ReduceT = typename decltype(p.first)::ValueT;
static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<ReduceT*>(&p), 4);
collective::GlobalSum(info, reinterpret_cast<ReduceT*>(&p), 4);
GradientPair positive_sum{p.first}, negative_sum{p.second};
std::size_t total_rows = gpair.size();
collective::Allreduce<collective::Operation::kSum>(&total_rows, 1);
collective::GlobalSum(info, &total_rows, 1);
auto histogram_rounding =
GradientSumT{common::CreateRoundingFactor<T>(

View File

@@ -39,7 +39,7 @@ private:
GradientPairPrecise to_floating_point_;
public:
explicit GradientQuantiser(common::Span<GradientPair const> gpair);
GradientQuantiser(common::Span<GradientPair const> gpair, MetaInfo const& info);
XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
gpair.GetHess() * to_fixed_point_.GetHess());

View File

@@ -129,7 +129,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
int batch_idx;
std::size_t item_idx;
AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
auto op_res = op(ridx[item_idx], batch_info_itr[batch_idx].data);
auto op_res = op(ridx[item_idx], batch_idx, batch_info_itr[batch_idx].data);
return IndexFlagTuple{static_cast<bst_uint>(item_idx), op_res, batch_idx, op_res};
});
size_t temp_bytes = 0;