/**
 * Copyright 2020-2024, XGBoost Contributors
 */
#include <thrust/iterator/transform_iterator.h>  // for make_transform_iterator

#include <algorithm>  // for max, min
#include <cstdint>    // uint32_t, int32_t

#include "../../collective/aggregator.h"
#include "../../common/deterministic.cuh"
#include "../../common/device_helpers.cuh"
#include "../../data/ellpack_page.cuh"
#include "histogram.cuh"
#include "row_partitioner.cuh"
#include "xgboost/base.h"

namespace xgboost::tree {
namespace {
struct Pair {
  GradientPair first;
  GradientPair second;
};
__host__ XGBOOST_DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
  return {lhs.first + rhs.first, lhs.second + rhs.second};
}

XGBOOST_DEV_INLINE bst_feature_t FeatIdx(FeatureGroup const& group, bst_idx_t idx,
                                         std::int32_t feature_stride) {
  auto fidx = group.start_feature + idx % feature_stride;
  return fidx;
}

XGBOOST_DEV_INLINE bst_idx_t IterIdx(EllpackDeviceAccessor const& matrix,
                                     RowPartitioner::RowIndexT ridx, bst_feature_t fidx) {
  // ridx_local = ridx - base_rowid      <== Row index local to each batch
  // entry_idx = ridx_local * row_stride <== Starting entry index for this row in the matrix
  // entry_idx += start_feature          <== Inside a row, first column inside this feature group
  // idx % feature_stride                <== The feature index local to the current feature group
  // entry_idx += idx % feature_stride   <== Final index.
  return (ridx - matrix.base_rowid) * matrix.row_stride + fidx;
}
}  // anonymous namespace

struct Clip : public thrust::unary_function<GradientPair, Pair> {
  static XGBOOST_DEV_INLINE float Pclip(float v) { return v > 0 ? v : 0; }
  static XGBOOST_DEV_INLINE float Nclip(float v) { return v < 0 ? abs(v) : 0; }

  XGBOOST_DEV_INLINE Pair operator()(GradientPair x) const {
    auto pg = Pclip(x.GetGrad());
    auto ph = Pclip(x.GetHess());

    auto ng = Nclip(x.GetGrad());
    auto nh = Nclip(x.GetHess());

    return {GradientPair{pg, ph}, GradientPair{ng, nh}};
  }
};

/**
 * In algorithm 5 (see common::CreateRoundingFactor) the bound is calculated as
 * $max(|v_i|) * n$.  Here we use the bound:
 *
 * \begin{equation}
 *   max( fl(\sum^{V}_{v_i>0}{v_i}), fl(\sum^{V}_{v_i<0}|v_i|) )
 * \end{equation}
 *
 * to avoid outliers, as the full reduction is reproducible on GPU with a reduction tree.
 */
GradientQuantiser::GradientQuantiser(Context const* ctx, common::Span<GradientPair const> gpair,
                                     MetaInfo const& info) {
  using GradientSumT = GradientPairPrecise;
  using T = typename GradientSumT::ValueT;

  thrust::device_ptr<GradientPair const> gpair_beg{gpair.data()};
  auto beg = thrust::make_transform_iterator(gpair_beg, Clip());
  Pair p = dh::Reduce(ctx->CUDACtx()->CTP(), beg, beg + gpair.size(), Pair{}, thrust::plus<Pair>{});
  // Treat the pair as an array of 4 primitive types to allreduce
  using ReduceT = typename decltype(p.first)::ValueT;
  static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
  auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(reinterpret_cast<ReduceT*>(&p), 4));
  collective::SafeColl(rc);

  GradientPair positive_sum{p.first}, negative_sum{p.second};

  std::size_t total_rows = gpair.size();
  rc = collective::GlobalSum(ctx, info, linalg::MakeVec(&total_rows, 1));
  collective::SafeColl(rc);

  auto histogram_rounding =
      GradientSumT{common::CreateRoundingFactor<T>(
                       std::max(positive_sum.GetGrad(), negative_sum.GetGrad()), total_rows),
                   common::CreateRoundingFactor<T>(
                       std::max(positive_sum.GetHess(), negative_sum.GetHess()), total_rows)};

  using IntT = typename GradientPairInt64::ValueT;
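  // For T = double the shift below is sizeof(T) * 8 - 2 == 62 bits, so the rounding
  // factor is spread over a 2^62 fixed-point scale. to_fixed_point_ is computed as the
  // exact inverse of to_floating_point_, which is the factor ToFixedPoint (used by the
  // histogram kernels below) multiplies each gradient by.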
  /**
   * Factor for converting gradients from fixed-point to floating-point.
   */
  to_floating_point_ =
      histogram_rounding /
      static_cast<T>(static_cast<IntT>(1)
                     << (sizeof(typename GradientSumT::ValueT) * 8 - 2));  // keep 1 for sign bit
  /**
   * Factor for converting gradients from floating-point to fixed-point. For
   * f64:
   *
   *   Precision = 64 - 1 - log2(rounding)
   *
   * rounding is calculated as exp(m), see the rounding factor calculation for
   * details.
   */
  to_fixed_point_ = GradientSumT(static_cast<T>(1) / to_floating_point_.GetGrad(),
                                 static_cast<T>(1) / to_floating_point_.GetHess());
}

XGBOOST_DEV_INLINE void AtomicAddGpairShared(xgboost::GradientPairInt64* dest,
                                             xgboost::GradientPairInt64 const& gpair) {
  auto dst_ptr = reinterpret_cast<int64_t*>(dest);
  auto g = gpair.GetQuantisedGrad();
  auto h = gpair.GetQuantisedHess();

  AtomicAdd64As32(dst_ptr, g);
  AtomicAdd64As32(dst_ptr + 1, h);
}

// Global 64 bit integer atomics at the time of writing do not benefit from being separated
// into two 32 bit atomics
XGBOOST_DEV_INLINE void AtomicAddGpairGlobal(xgboost::GradientPairInt64* dest,
                                             xgboost::GradientPairInt64 const& gpair) {
  auto dst_ptr = reinterpret_cast<uint64_t*>(dest);
  auto g = gpair.GetQuantisedGrad();
  auto h = gpair.GetQuantisedHess();

  atomicAdd(dst_ptr, *reinterpret_cast<uint64_t*>(&g));
  atomicAdd(dst_ptr + 1, *reinterpret_cast<uint64_t*>(&h));
}

template <std::int32_t kBlockThreads, std::int32_t kItemsPerThread, bool kIsDense,
          std::int32_t kItemsPerTile = kBlockThreads * kItemsPerThread>
class HistogramAgent {
  GradientPairInt64* smem_arr_;
  GradientPairInt64* d_node_hist_;
  using Idx = RowPartitioner::RowIndexT;

  dh::LDGIterator<const Idx> d_ridx_;
  const GradientPair* d_gpair_;
  const FeatureGroup group_;
  const EllpackDeviceAccessor& matrix_;
  const int feature_stride_;
  const std::size_t n_elements_;
  const GradientQuantiser& rounding_;

 public:
  __device__ HistogramAgent(GradientPairInt64* smem_arr,
                            GradientPairInt64* __restrict__ d_node_hist, const FeatureGroup& group,
                            const EllpackDeviceAccessor& matrix, common::Span<const Idx> d_ridx,
                            const GradientQuantiser& rounding, const GradientPair* d_gpair)
      : smem_arr_(smem_arr),
        d_node_hist_(d_node_hist),
        d_ridx_(d_ridx.data()),
        group_(group),
        matrix_(matrix),
        feature_stride_(kIsDense ? group.num_features : matrix.row_stride),
        n_elements_(feature_stride_ * d_ridx.size()),
        rounding_(rounding),
        d_gpair_(d_gpair) {}

  __device__ void ProcessPartialTileShared(std::size_t offset) {
    for (std::size_t idx = offset + threadIdx.x;
         idx < std::min(offset + kBlockThreads * kItemsPerTile, n_elements_);
         idx += kBlockThreads) {
      Idx ridx = d_ridx_[idx / feature_stride_];
      auto fidx = FeatIdx(group_, idx, feature_stride_);
      bst_bin_t compressed_bin = matrix_.gidx_iter[IterIdx(matrix_, ridx, fidx)];
      if (kIsDense || compressed_bin != matrix_.NullValue()) {
        auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
        // Subtract start_bin to write to the group-local histogram. If this is not a dense
        // matrix, then start_bin is 0 since feature grouping doesn't support sparse data.
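        // For dense data the compressed bin index is local to its feature, so
        // feature_segments[fidx] first translates it to a global bin before the
        // group offset is subtracted.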
        if (kIsDense) {
          AtomicAddGpairShared(
              smem_arr_ + compressed_bin + this->matrix_.feature_segments[fidx] - group_.start_bin,
              adjusted);
        } else {
          AtomicAddGpairShared(smem_arr_ + compressed_bin - group_.start_bin, adjusted);
        }
      }
    }
  }

  // Instruction level parallelism by loop unrolling
  // Allows the kernel to pipeline many operations while waiting for global memory
  // Increases the throughput of this kernel significantly
  __device__ void ProcessFullTileShared(std::size_t offset) {
    std::size_t idx[kItemsPerThread];
    Idx ridx[kItemsPerThread];
    bst_bin_t gidx[kItemsPerThread];
    GradientPair gpair[kItemsPerThread];
#pragma unroll
    for (int i = 0; i < kItemsPerThread; i++) {
      idx[i] = offset + i * kBlockThreads + threadIdx.x;
    }
#pragma unroll
    for (int i = 0; i < kItemsPerThread; i++) {
      ridx[i] = d_ridx_[idx[i] / feature_stride_];
    }
#pragma unroll
    for (int i = 0; i < kItemsPerThread; i++) {
      gpair[i] = d_gpair_[ridx[i]];
      auto fidx = FeatIdx(group_, idx[i], feature_stride_);
      if (kIsDense) {
        gidx[i] =
            matrix_.gidx_iter[IterIdx(matrix_, ridx[i], fidx)] + matrix_.feature_segments[fidx];
      } else {
        gidx[i] = matrix_.gidx_iter[IterIdx(matrix_, ridx[i], fidx)];
      }
    }
#pragma unroll
    for (int i = 0; i < kItemsPerThread; i++) {
      if ((kIsDense || gidx[i] != matrix_.NullValue())) {
        auto adjusted = rounding_.ToFixedPoint(gpair[i]);
        AtomicAddGpairShared(smem_arr_ + gidx[i] - group_.start_bin, adjusted);
      }
    }
  }

  __device__ void BuildHistogramWithShared() {
    dh::BlockFill(smem_arr_, group_.num_bins, GradientPairInt64{});
    __syncthreads();

    std::size_t offset = blockIdx.x * kItemsPerTile;
    while (offset + kItemsPerTile <= n_elements_) {
      ProcessFullTileShared(offset);
      offset += kItemsPerTile * gridDim.x;
    }
    ProcessPartialTileShared(offset);

    // Write shared memory back to global memory
    __syncthreads();
    for (auto i : dh::BlockStrideRange(0, group_.num_bins)) {
      AtomicAddGpairGlobal(d_node_hist_ + group_.start_bin + i, smem_arr_[i]);
    }
  }

  __device__ void BuildHistogramWithGlobal() {
    for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n_elements_)) {
      Idx ridx = d_ridx_[idx / feature_stride_];
      auto fidx = FeatIdx(group_, idx, feature_stride_);
      bst_bin_t compressed_bin = matrix_.gidx_iter[IterIdx(matrix_, ridx, fidx)];
      if (kIsDense || compressed_bin != matrix_.NullValue()) {
        auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
        if (kIsDense) {
          auto start_bin = this->matrix_.feature_segments[fidx];
          AtomicAddGpairGlobal(d_node_hist_ + compressed_bin + start_bin, adjusted);
        } else {
          AtomicAddGpairGlobal(d_node_hist_ + compressed_bin, adjusted);
        }
      }
    }
  }
};

template <std::int32_t kBlockThreads, std::int32_t kItemsPerThread,
          bool use_shared_memory_histograms, bool kIsDense>
__global__ void __launch_bounds__(kBlockThreads)
    SharedMemHistKernel(const EllpackDeviceAccessor matrix,
                        const FeatureGroupsAccessor feature_groups,
                        common::Span<const RowPartitioner::RowIndexT> d_ridx,
                        GradientPairInt64* __restrict__ d_node_hist,
                        const GradientPair* __restrict__ d_gpair,
                        GradientQuantiser const rounding) {
  extern __shared__ char smem[];
  const FeatureGroup group = feature_groups[blockIdx.y];
  auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem);
  auto agent = HistogramAgent<kBlockThreads, kItemsPerThread, kIsDense>(
      smem_arr, d_node_hist, group, matrix, d_ridx, rounding, d_gpair);
  if (use_shared_memory_histograms) {
    agent.BuildHistogramWithShared();
  } else {
    agent.BuildHistogramWithGlobal();
  }
}

namespace {
constexpr std::int32_t kBlockThreads = 1024;
constexpr std::int32_t kItemsPerThread = 8;
constexpr std::int32_t ItemsPerTile() { return kBlockThreads * kItemsPerThread; }
}  // namespace
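// The struct below instantiates the four variants of SharedMemHistKernel, covering
// {global, shared} memory x {sparse, dense} Ellpack layouts, and computes a common
// launch configuration for them.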
// Use auto deduction guide to workaround compiler error.
template <auto GlobalDense = SharedMemHistKernel<kBlockThreads, kItemsPerThread, false, true>,
          auto Global = SharedMemHistKernel<kBlockThreads, kItemsPerThread, false, false>,
          auto SharedDense = SharedMemHistKernel<kBlockThreads, kItemsPerThread, true, true>,
          auto Shared = SharedMemHistKernel<kBlockThreads, kItemsPerThread, true, false>>
struct HistogramKernel {
  // Kernel for working with dense Ellpack using the global memory.
  decltype(Global) global_dense_kernel{
      SharedMemHistKernel<kBlockThreads, kItemsPerThread, false, true>};
  // Kernel for working with sparse Ellpack using the global memory.
  decltype(Global) global_kernel{SharedMemHistKernel<kBlockThreads, kItemsPerThread, false, false>};
  // Kernel for working with dense Ellpack using the shared memory.
  decltype(Shared) shared_dense_kernel{
      SharedMemHistKernel<kBlockThreads, kItemsPerThread, true, true>};
  // Kernel for working with sparse Ellpack using the shared memory.
  decltype(Shared) shared_kernel{SharedMemHistKernel<kBlockThreads, kItemsPerThread, true, false>};

  bool shared{false};
  std::uint32_t grid_size{0};
  std::size_t smem_size{0};

  HistogramKernel(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
                  bool force_global_memory) {
    // Decide whether to use shared memory
    // Opt into maximum shared memory for the kernel if necessary
    std::size_t max_shared_memory = dh::MaxSharedMemoryOptin(ctx->Ordinal());

    this->smem_size = sizeof(GradientPairInt64) * feature_groups.max_group_bins;
    this->shared = !force_global_memory && smem_size <= max_shared_memory;
    this->smem_size = this->shared ? this->smem_size : 0;

    auto init = [&](auto& kernel) {
      if (this->shared) {
        dh::safe_cuda(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                           max_shared_memory));
      }

      // determine the launch configuration
      std::int32_t num_groups = feature_groups.NumGroups();
      std::int32_t n_mps = 0;
      dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, ctx->Ordinal()));

      std::int32_t n_blocks_per_mp = 0;
      dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
                                                                  kBlockThreads, this->smem_size));

      // This gives the number of blocks to keep the device occupied. Use this as the
      // maximum number of blocks.
      this->grid_size = n_blocks_per_mp * n_mps;
    };

    // Initialize all kernel instantiations
    for (auto& kernel : {global_dense_kernel, global_kernel, shared_dense_kernel, shared_kernel}) {
      init(kernel);
    }
  }
};

class DeviceHistogramBuilderImpl {
  std::unique_ptr<HistogramKernel<>> kernel_{nullptr};

 public:
  void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
             bool force_global_memory) {
    this->kernel_ =
        std::make_unique<HistogramKernel<>>(ctx, feature_groups, force_global_memory);
    if (force_global_memory) {
      CHECK(!this->kernel_->shared);
    }
  }

  void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
                      FeatureGroupsAccessor const& feature_groups,
                      common::Span<GradientPair const> gpair,
                      common::Span<const RowPartitioner::RowIndexT> d_ridx,
                      common::Span<GradientPairInt64> histogram,
                      GradientQuantiser rounding) const {
    CHECK(kernel_);
    // Launch blocks such that each block has a minimum amount of work to do
    // There are fixed costs to launching each block, e.g. zeroing shared memory
    // The below amount of minimum work was found by experimentation
    int columns_per_group = common::DivRoundUp(matrix.row_stride, feature_groups.NumGroups());
    // Average number of matrix elements processed by each group
    std::size_t items_per_group = d_ridx.size() * columns_per_group;
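    // With the launch constants above (1024 threads x 8 items per thread), the minimum
    // amount of work assigned to a block below comes to 8192 Ellpack entries.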
    // Allocate number of blocks such that each block has about kMinItemsPerBlock work
    // Up to a maximum where the device is saturated
    auto constexpr kMinItemsPerBlock = ItemsPerTile();
    auto grid_size = std::min(kernel_->grid_size, static_cast<std::uint32_t>(common::DivRoundUp(
                                                      items_per_group, kMinItemsPerBlock)));

    auto launcher = [&](auto kernel) {
      dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()),  // NOLINT
                       static_cast<std::uint32_t>(kBlockThreads), kernel_->smem_size,
                       ctx->Stream()}(kernel, matrix, feature_groups, d_ridx, histogram.data(),
                                      gpair.data(), rounding);
    };

    if (!this->kernel_->shared) {
      CHECK_EQ(this->kernel_->smem_size, 0);
      if (matrix.is_dense) {
        launcher(this->kernel_->global_dense_kernel);
      } else {
        launcher(this->kernel_->global_kernel);
      }
    } else {
      CHECK_NE(this->kernel_->smem_size, 0);
      if (matrix.is_dense) {
        launcher(this->kernel_->shared_dense_kernel);
      } else {
        launcher(this->kernel_->shared_kernel);
      }
    }
  }
};

DeviceHistogramBuilder::DeviceHistogramBuilder()
    : p_impl_{std::make_unique<DeviceHistogramBuilderImpl>()} {
  monitor_.Init(__func__);
}

DeviceHistogramBuilder::~DeviceHistogramBuilder() = default;

void DeviceHistogramBuilder::Reset(Context const* ctx, std::size_t max_cached_hist_nodes,
                                   FeatureGroupsAccessor const& feature_groups,
                                   bst_bin_t n_total_bins, bool force_global_memory) {
  this->monitor_.Start(__func__);
  this->p_impl_->Reset(ctx, feature_groups, force_global_memory);
  this->hist_.Reset(ctx, n_total_bins, max_cached_hist_nodes);
  this->monitor_.Stop(__func__);
}

void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
                                            EllpackDeviceAccessor const& matrix,
                                            FeatureGroupsAccessor const& feature_groups,
                                            common::Span<GradientPair const> gpair,
                                            common::Span<const RowPartitioner::RowIndexT> ridx,
                                            common::Span<GradientPairInt64> histogram,
                                            GradientQuantiser rounding) {
  this->monitor_.Start(__func__);
  this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram, rounding);
  this->monitor_.Stop(__func__);
}

void DeviceHistogramBuilder::AllReduceHist(Context const* ctx, MetaInfo const& info,
                                           bst_node_t nidx, std::size_t num_histograms) {
  this->monitor_.Start(__func__);
  auto d_node_hist = hist_.GetNodeHistogram(nidx);
  using ReduceT = typename std::remove_pointer<decltype(d_node_hist.data())>::type::ValueT;
  // Each GradientPairInt64 holds two ReduceT values, hence the factor of 2 in the size.
  auto rc = collective::GlobalSum(
      ctx, info,
      linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist.data()),
                      d_node_hist.size() * 2 * num_histograms, ctx->Device()));
  SafeColl(rc);

  this->monitor_.Stop(__func__);
}
}  // namespace xgboost::tree