Support column split in gpu hist updater (#9384)
This commit is contained in:
@@ -418,7 +418,8 @@ void GPUHistEvaluator::EvaluateSplits(
|
||||
|
||||
// Reduce to get the best candidate from all workers.
|
||||
dh::LaunchN(out_splits.size(), [world_size, all_candidates, out_splits] __device__(size_t i) {
|
||||
for (auto rank = 0; rank < world_size; rank++) {
|
||||
out_splits[i] = all_candidates[i];
|
||||
for (auto rank = 1; rank < world_size; rank++) {
|
||||
out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
|
||||
}
|
||||
});
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <cstdint> // uint32_t
|
||||
#include <limits>
|
||||
|
||||
#include "../../collective/aggregator.h"
|
||||
#include "../../common/deterministic.cuh"
|
||||
#include "../../common/device_helpers.cuh"
|
||||
#include "../../data/ellpack_page.cuh"
|
||||
@@ -52,7 +53,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
|
||||
*
|
||||
* to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
|
||||
*/
|
||||
GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair) {
|
||||
GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair, MetaInfo const& info) {
|
||||
using GradientSumT = GradientPairPrecise;
|
||||
using T = typename GradientSumT::ValueT;
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
@@ -64,11 +65,11 @@ GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair) {
|
||||
// Treat pair as array of 4 primitive types to allreduce
|
||||
using ReduceT = typename decltype(p.first)::ValueT;
|
||||
static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
|
||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<ReduceT*>(&p), 4);
|
||||
collective::GlobalSum(info, reinterpret_cast<ReduceT*>(&p), 4);
|
||||
GradientPair positive_sum{p.first}, negative_sum{p.second};
|
||||
|
||||
std::size_t total_rows = gpair.size();
|
||||
collective::Allreduce<collective::Operation::kSum>(&total_rows, 1);
|
||||
collective::GlobalSum(info, &total_rows, 1);
|
||||
|
||||
auto histogram_rounding =
|
||||
GradientSumT{common::CreateRoundingFactor<T>(
|
||||
|
||||
@@ -39,7 +39,7 @@ private:
|
||||
GradientPairPrecise to_floating_point_;
|
||||
|
||||
public:
|
||||
explicit GradientQuantiser(common::Span<GradientPair const> gpair);
|
||||
GradientQuantiser(common::Span<GradientPair const> gpair, MetaInfo const& info);
|
||||
XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
|
||||
auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
|
||||
gpair.GetHess() * to_fixed_point_.GetHess());
|
||||
|
||||
@@ -129,7 +129,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
|
||||
int batch_idx;
|
||||
std::size_t item_idx;
|
||||
AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
|
||||
auto op_res = op(ridx[item_idx], batch_info_itr[batch_idx].data);
|
||||
auto op_res = op(ridx[item_idx], batch_idx, batch_info_itr[batch_idx].data);
|
||||
return IndexFlagTuple{static_cast<bst_uint>(item_idx), op_res, batch_idx, op_res};
|
||||
});
|
||||
size_t temp_bytes = 0;
|
||||
|
||||
Reference in New Issue
Block a user