diff --git a/src/tree/gpu_hist/feature_groups.cu b/src/tree/gpu_hist/feature_groups.cu
index 27ed9bd91..52e58da7e 100644
--- a/src/tree/gpu_hist/feature_groups.cu
+++ b/src/tree/gpu_hist/feature_groups.cu
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2020 by XGBoost Contributors
+/**
+ * Copyright 2020-2024, XGBoost Contributors
  */
 #include 
 
@@ -8,12 +8,9 @@
 #include "feature_groups.cuh"
 
-#include "../../common/device_helpers.cuh"
 #include "../../common/hist_util.h"
 
-namespace xgboost {
-namespace tree {
-
+namespace xgboost::tree {
 FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
                              size_t shm_size, size_t bin_size) {
   // Only use a single feature group for sparse matrices.
@@ -59,6 +56,4 @@ void FeatureGroups::InitSingle(const common::HistogramCuts& cuts) {
 
   max_group_bins = cuts.TotalBins();
 }
-
-}  // namespace tree
-}  // namespace xgboost
+}  // namespace xgboost::tree
diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index 90c151556..cd848c1c0 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -5,8 +5,7 @@
 #include 
 #include 
 
-#include <cstdint>  // uint32_t
-#include 
+#include <cstdint>  // uint32_t, int32_t
 
 #include "../../collective/aggregator.h"
 #include "../../common/deterministic.cuh"
@@ -128,7 +127,7 @@ XGBOOST_DEV_INLINE void AtomicAddGpairGlobal(xgboost::GradientPairInt64* dest,
 }
 
 template <int kBlockThreads, int kItemsPerThread,
-          int kItemsPerTile = kBlockThreads*kItemsPerThread>
+          int kItemsPerTile = kBlockThreads * kItemsPerThread>
 class HistogramAgent {
   GradientPairInt64* smem_arr_;
   GradientPairInt64* d_node_hist_;
@@ -244,53 +243,82 @@ __global__ void __launch_bounds__(kBlockThreads)
   extern __shared__ char smem[];
   const FeatureGroup group = feature_groups[blockIdx.y];
   auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem);
-  auto agent = HistogramAgent<kBlockThreads, kItemsPerThread>(
-      smem_arr, d_node_hist, group, matrix, d_ridx, rounding, d_gpair);
+  auto agent = HistogramAgent<kBlockThreads, kItemsPerThread>(smem_arr, d_node_hist, group, matrix,
+                                                              d_ridx, rounding, d_gpair);
   if (use_shared_memory_histograms) {
     agent.BuildHistogramWithShared();
   } else {
     agent.BuildHistogramWithGlobal();
   }
 }
+
+namespace {
+constexpr std::int32_t kBlockThreads = 1024;
+constexpr std::int32_t kItemsPerThread = 8;
+constexpr std::int32_t ItemsPerTile() { return kBlockThreads * kItemsPerThread; }
+}  // namespace
 
-void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
-                            FeatureGroupsAccessor const& feature_groups,
-                            common::Span<GradientPair const> gpair,
-                            common::Span<const std::uint32_t> d_ridx,
-                            common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
-                            bool force_global_memory) {
-  // decide whether to use shared memory
-  int device = 0;
-  dh::safe_cuda(cudaGetDevice(&device));
-  // opt into maximum shared memory for the kernel if necessary
-  size_t max_shared_memory = dh::MaxSharedMemoryOptin(device);
+// Use auto deduction guide to workaround compiler error.
+template <auto Global = SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>,
+          auto Shared = SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>>
+struct HistogramKernel {
+  decltype(Global) global_kernel{SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>};
+  decltype(Shared) shared_kernel{SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>};
+  bool shared{false};
+  std::uint32_t grid_size{0};
+  std::size_t smem_size{0};
 
-  size_t smem_size =
-      sizeof(GradientPairInt64) * feature_groups.max_group_bins;
-  bool shared = !force_global_memory && smem_size <= max_shared_memory;
-  smem_size = shared ? smem_size : 0;
+  HistogramKernel(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
+                  bool force_global_memory) {
+    // Decide whether to use shared memory
+    // Opt into maximum shared memory for the kernel if necessary
+    std::size_t max_shared_memory = dh::MaxSharedMemoryOptin(ctx->Ordinal());
 
-  constexpr int kBlockThreads = 1024;
-  constexpr int kItemsPerThread = 8;
-  constexpr int kItemsPerTile = kBlockThreads * kItemsPerThread;
+    this->smem_size = sizeof(GradientPairInt64) * feature_groups.max_group_bins;
+    this->shared = !force_global_memory && smem_size <= max_shared_memory;
+    this->smem_size = this->shared ? this->smem_size : 0;
 
-  auto runit = [&, kMinItemsPerBlock = kItemsPerTile](auto kernel) {
-    if (shared) {
-      dh::safe_cuda(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
-                                         max_shared_memory));
-    }
+    auto init = [&](auto& kernel) {
+      if (this->shared) {
+        dh::safe_cuda(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                           max_shared_memory));
+      }
 
-    // determine the launch configuration
-    int num_groups = feature_groups.NumGroups();
-    int n_mps = 0;
-    dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
-    int n_blocks_per_mp = 0;
-    dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
-                                                                kBlockThreads, smem_size));
-    // This gives the number of blocks to keep the device occupied
-    // Use this as the maximum number of blocks
-    unsigned grid_size = n_blocks_per_mp * n_mps;
+      // determine the launch configuration
+      std::int32_t num_groups = feature_groups.NumGroups();
+      std::int32_t n_mps = 0;
+      dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, ctx->Ordinal()));
+      std::int32_t n_blocks_per_mp = 0;
+      dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
+                                                                  kBlockThreads, this->smem_size));
+
+      // This gives the number of blocks to keep the device occupied Use this as the
+      // maximum number of blocks
+      this->grid_size = n_blocks_per_mp * n_mps;
+    };
+
+    init(this->global_kernel);
+    init(this->shared_kernel);
+  }
+};
+
+class DeviceHistogramBuilderImpl {
+  std::unique_ptr<HistogramKernel<>> kernel_{nullptr};
+  bool force_global_memory_{false};
+
+ public:
+  void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
+             bool force_global_memory) {
+    this->kernel_ = std::make_unique<HistogramKernel<>>(ctx, feature_groups, force_global_memory);
+    this->force_global_memory_ = force_global_memory;
+  }
+
+  void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
+                      FeatureGroupsAccessor const& feature_groups,
+                      common::Span<GradientPair const> gpair,
+                      common::Span<const std::uint32_t> d_ridx,
+                      common::Span<GradientPairInt64> histogram, GradientQuantiser rounding) {
+    CHECK(kernel_);
     // Otherwise launch blocks such that each block has a minimum amount of work to do
     // There are fixed costs to launching each block, e.g. zeroing shared memory
     // The below amount of minimum work was found by experimentation
@@ -300,20 +328,41 @@ void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const&
 
     // Allocate number of blocks such that each block has about kMinItemsPerBlock work
     // Up to a maximum where the device is saturated
-    grid_size = std::min(grid_size, static_cast<unsigned>(
-                                        common::DivRoundUp(items_per_group, kMinItemsPerBlock)));
+    auto constexpr kMinItemsPerBlock = ItemsPerTile();
+    auto grid_size = std::min(kernel_->grid_size, static_cast<std::uint32_t>(common::DivRoundUp(
+                                                      items_per_group, kMinItemsPerBlock)));
 
-    dh::LaunchKernel {dim3(grid_size, num_groups), static_cast<uint32_t>(kBlockThreads), smem_size,
-                      ctx->Stream()} (kernel, matrix, feature_groups, d_ridx, histogram.data(),
-                                      gpair.data(), rounding);
-  };
-
-  if (shared) {
-    runit(SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>);
-  } else {
-    runit(SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>);
+    if (this->force_global_memory_ || !this->kernel_->shared) {
+      dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()),  // NOLINT
+                       static_cast<std::uint32_t>(kBlockThreads), kernel_->smem_size,
+                       ctx->Stream()}(kernel_->global_kernel, matrix, feature_groups, d_ridx,
+                                      histogram.data(), gpair.data(), rounding);
+    } else {
+      dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()),  // NOLINT
+                       static_cast<std::uint32_t>(kBlockThreads), kernel_->smem_size,
+                       ctx->Stream()}(kernel_->shared_kernel, matrix, feature_groups, d_ridx,
+                                      histogram.data(), gpair.data(), rounding);
+    }
   }
+};
 
-  dh::safe_cuda(cudaGetLastError());
+DeviceHistogramBuilder::DeviceHistogramBuilder()
+    : p_impl_{std::make_unique<DeviceHistogramBuilderImpl>()} {}
+
+DeviceHistogramBuilder::~DeviceHistogramBuilder() = default;
+
+void DeviceHistogramBuilder::Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
+                                   bool force_global_memory) {
+  this->p_impl_->Reset(ctx, feature_groups, force_global_memory);
+}
+
+void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
+                                            EllpackDeviceAccessor const& matrix,
+                                            FeatureGroupsAccessor const& feature_groups,
+                                            common::Span<GradientPair const> gpair,
+                                            common::Span<const std::uint32_t> ridx,
+                                            common::Span<GradientPairInt64> histogram,
+                                            GradientQuantiser rounding) {
+  this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram, rounding);
 }
 }  // namespace xgboost::tree
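
The launch configuration that HistogramKernel caches above has two inputs: an occupancy query for the kernel at the chosen block size and dynamic shared-memory footprint, and a clamp so that every block gets at least ItemsPerTile() items of work. The following is a minimal standalone sketch of that computation using the raw CUDA runtime API instead of the dh:: wrappers; the kernel name HistKernel, the hand-rolled DivRoundUp, and the omitted error checking are illustrative, not part of the patch.

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <cuda_runtime.h>

    __global__ void HistKernel() {}  // stand-in for SharedMemHistKernel<...>

    // Ceiling division, mirroring common::DivRoundUp.
    std::size_t DivRoundUp(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

    std::uint32_t GridSize(int device, int block_threads, std::size_t smem_size,
                           std::size_t items_per_group, std::size_t min_items_per_block) {
      // Enough blocks to keep every SM busy for this kernel and shared-memory size...
      int n_mps = 0;
      cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device);
      int n_blocks_per_mp = 0;
      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, HistKernel, block_threads,
                                                    smem_size);
      std::uint32_t grid_size = n_blocks_per_mp * n_mps;
      // ...but never more blocks than the available work justifies.
      return std::min(grid_size,
                      static_cast<std::uint32_t>(DivRoundUp(items_per_group, min_items_per_block)));
    }

The patch runs the occupancy part once per Reset() and applies the work clamp on every BuildHistogram() call, since only the clamp depends on how many rows the node holds.
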
diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh
index 925c54893..e30f68208 100644
--- a/src/tree/gpu_hist/histogram.cuh
+++ b/src/tree/gpu_hist/histogram.cuh
@@ -1,17 +1,18 @@
-/*!
- * Copyright 2020-2021 by XGBoost Contributors
+/**
+ * Copyright 2020-2024, XGBoost Contributors
  */
 #ifndef HISTOGRAM_CUH_
 #define HISTOGRAM_CUH_
-#include 
+#include <memory>  // for unique_ptr
 
-#include "../../common/cuda_context.cuh"
-#include "../../data/ellpack_page.cuh"
-#include "feature_groups.cuh"
-
-namespace xgboost {
-namespace tree {
+#include "../../common/cuda_context.cuh"  // for CUDAContext
+#include "../../data/ellpack_page.cuh"    // for EllpackDeviceAccessor
+#include "feature_groups.cuh"             // for FeatureGroupsAccessor
+#include "xgboost/base.h"                 // for GradientPair, GradientPairInt64
+#include "xgboost/context.h"              // for Context
+#include "xgboost/span.h"                 // for Span
 
+namespace xgboost::tree {
 /**
  * \brief An atomicAdd designed for gradient pair with better performance. For general
  * int64_t atomicAdd, one can simply cast it to unsigned long long. Exposed for testing.
@@ -32,7 +33,7 @@ XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t* dst, int64_t src) {
 }
 
 class GradientQuantiser {
-private:
+ private:
   /* Convert gradient to fixed point representation. */
   GradientPairPrecise to_fixed_point_;
   /* Convert fixed point representation back to floating point. */
@@ -59,13 +60,23 @@ private:
   }
 };
 
-void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
-                            FeatureGroupsAccessor const& feature_groups,
-                            common::Span<GradientPair const> gpair,
-                            common::Span<const std::uint32_t> ridx,
-                            common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
-                            bool force_global_memory = false);
-}  // namespace tree
-}  // namespace xgboost
+class DeviceHistogramBuilderImpl;
+
+class DeviceHistogramBuilder {
+  std::unique_ptr<DeviceHistogramBuilderImpl> p_impl_;
+
+ public:
+  DeviceHistogramBuilder();
+  ~DeviceHistogramBuilder();
+
+  void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
+             bool force_global_memory);
+  void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
+                      FeatureGroupsAccessor const& feature_groups,
+                      common::Span<GradientPair const> gpair,
+                      common::Span<const std::uint32_t> ridx,
+                      common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
+};
+}  // namespace xgboost::tree
 
 #endif  // HISTOGRAM_CUH_
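
Together with the histogram.cu changes above, the header now exposes a two-step API: Reset() picks shared versus global memory (and caches the kernel attributes plus the occupancy-derived grid size), while BuildHistogram() only launches. A call site therefore looks roughly like the sketch below, which mirrors the updated tests further down; the wrapper function and its parameter list are illustrative rather than code from the patch, and they assume the objects are prepared the same way the tests prepare them.

    // Sketch only: assumes the includes and setup from
    // tests/cpp/tree/gpu_hist/test_histogram.cu (histogram.cuh, helpers, etc.).
    void BuildNodeHistogram(Context const& ctx, EllpackPageImpl const* page,
                            FeatureGroups& feature_groups, HostDeviceVector<GradientPair>& gpair,
                            common::Span<const std::uint32_t> ridx,
                            common::Span<GradientPairInt64> d_histogram,
                            GradientQuantiser quantiser, bool force_global) {
      DeviceHistogramBuilder builder;
      // The force-global decision moved out of the old BuildGradientHistogram() trailing
      // argument and into Reset(); BuildHistogram() no longer takes it.
      builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
      builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
                             feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                             d_histogram, quantiser);
    }

In updater_gpu_hist.cu (next file) the builder becomes a member of GPUHistMakerDevice, so Reset() runs once per maker reset while BuildHist() is invoked per node.
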
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 958fa0331..aa4f8fa27 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -162,6 +162,8 @@ struct GPUHistMakerDevice {
   std::shared_ptr<common::ColumnSampler> column_sampler_;
   MetaInfo const& info_;
 
+  DeviceHistogramBuilder histogram_;
+
  public:
   EllpackPageImpl const* page{nullptr};
   common::Span<FeatureType const> feature_types;
@@ -256,6 +258,8 @@ struct GPUHistMakerDevice {
     hist.Reset();
 
     this->InitFeatureGroupsOnce();
+
+    this->histogram_.Reset(ctx_, feature_groups->DeviceAccessor(ctx_->Device()), false);
   }
 
   GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) {
@@ -340,9 +344,9 @@ struct GPUHistMakerDevice {
   void BuildHist(int nidx) {
     auto d_node_hist = hist.GetNodeHistogram(nidx);
     auto d_ridx = row_partitioner->GetRows(nidx);
-    BuildGradientHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
-                           feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
-                           d_node_hist, *quantiser);
+    this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
+                                    feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
+                                    d_node_hist, *quantiser);
   }
 
   // Attempt to do subtraction trick
diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu
index 84cd956db..3b9e6103a 100644
--- a/tests/cpp/tree/gpu_hist/test_histogram.cu
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -1,11 +1,10 @@
 /**
- * Copyright 2020-2023, XGBoost Contributors
+ * Copyright 2020-2024, XGBoost Contributors
  */
 #include 
 #include 
 
-#include "../../../../src/common/categorical.h"
 #include "../../../../src/tree/gpu_hist/histogram.cuh"
 #include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
 #include "../../../../src/tree/param.h"  // TrainParam
@@ -13,7 +12,7 @@
 #include "../../helpers.h"
 
 namespace xgboost::tree {
-void TestDeterministicHistogram(bool is_dense, int shm_size) {
+void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) {
   Context ctx = MakeCUDACtx(0);
   size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
   float constexpr kLower = -1e-2, kUpper = 1e2;
@@ -25,35 +24,37 @@
   for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) {
     auto* page = batch.Impl();
 
-    tree::RowPartitioner row_partitioner(FstCU(), kRows);
+    tree::RowPartitioner row_partitioner(ctx.Device(), kRows);
     auto ridx = row_partitioner.GetRows(0);
 
-    int num_bins = kBins * kCols;
+    bst_bin_t num_bins = kBins * kCols;
     dh::device_vector<GradientPairInt64> histogram(num_bins);
     auto d_histogram = dh::ToSpan(histogram);
     auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
-    gpair.SetDevice(FstCU());
+    gpair.SetDevice(ctx.Device());
 
-    FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size,
-                                 sizeof(GradientPairInt64));
+    FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, sizeof(GradientPairInt64));
 
     auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
-    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()),
-                           feature_groups.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx,
+    DeviceHistogramBuilder builder;
+    builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
+    builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+                           feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                            d_histogram, quantiser);
 
     std::vector<GradientPairInt64> histogram_h(num_bins);
     dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(),
-                             num_bins * sizeof(GradientPairInt64),
-                             cudaMemcpyDeviceToHost));
+                             num_bins * sizeof(GradientPairInt64), cudaMemcpyDeviceToHost));
 
-    for (size_t i = 0; i < kRounds; ++i) {
+    for (std::size_t i = 0; i < kRounds; ++i) {
       dh::device_vector<GradientPairInt64> new_histogram(num_bins);
       auto d_new_histogram = dh::ToSpan(new_histogram);
 
       auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
-      BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()),
-                             feature_groups.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx,
+      DeviceHistogramBuilder builder;
+      builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
+      builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+                             feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                              d_new_histogram, quantiser);
 
       std::vector<GradientPairInt64> new_histogram_h(num_bins);
@@ -68,14 +69,16 @@
 
     {
       auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
-      gpair.SetDevice(FstCU());
+      gpair.SetDevice(ctx.Device());
 
       // Use a single feature group to compute the baseline.
      FeatureGroups single_group(page->Cuts());
 
      dh::device_vector<GradientPairInt64> baseline(num_bins);
-      BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()),
-                             single_group.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx,
+      DeviceHistogramBuilder builder;
+      builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), force_global);
+      builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+                             single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                              dh::ToSpan(baseline), quantiser);
 
       std::vector<GradientPairInt64> baseline_h(num_bins);
@@ -96,7 +99,9 @@ TEST(Histogram, GPUDeterministic) {
   std::vector<int> shm_sizes{48 * 1024, 64 * 1024, 160 * 1024};
   for (bool is_dense : is_dense_array) {
     for (int shm_size : shm_sizes) {
-      TestDeterministicHistogram(is_dense, shm_size);
+      for (bool force_global : {true, false}) {
+        TestDeterministicHistogram(is_dense, shm_size, force_global);
+      }
     }
   }
 }
@@ -136,7 +141,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   for (auto const &batch : cat_m->GetBatches<EllpackPage>(&ctx, batch_param)) {
     auto* page = batch.Impl();
     FeatureGroups single_group(page->Cuts());
-    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+    DeviceHistogramBuilder builder;
+    builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
+    builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
                            single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                            dh::ToSpan(cat_hist), quantiser);
   }
@@ -150,7 +157,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   for (auto const &batch : encode_m->GetBatches<EllpackPage>(&ctx, batch_param)) {
     auto* page = batch.Impl();
     FeatureGroups single_group(page->Cuts());
-    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+    DeviceHistogramBuilder builder;
+    builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
+    builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
                            single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                            dh::ToSpan(encode_hist), quantiser);
   }
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index cc4d9fb7f..1c156563c 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2023 by XGBoost contributors
+ * Copyright 2017-2024, XGBoost contributors
  */
 #include 
 #include 
@@ -22,12 +22,8 @@
 #include "xgboost/context.h"
 #include "xgboost/json.h"
 
-#if defined(XGBOOST_USE_FEDERATED)
-#include "../plugin/federated/test_worker.h"  // for TestFederatedGlobal
-#endif  // defined(XGBOOST_USE_FEDERATED)
-
 namespace xgboost::tree {
-TEST(GpuHist, DeviceHistogram) {
+TEST(GpuHist, DeviceHistogramStorage) {
   // Ensures that node allocates correctly after reaching `kStopGrowingSize`.
   dh::safe_cuda(cudaSetDevice(0));
   constexpr size_t kNBins = 128;
@@ -102,17 +98,17 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   xgboost::SimpleLCG gen;
   xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
   HostDeviceVector<GradientPair> gpair(kNRows);
-  for (auto &gp : gpair.HostVector()) {
-    bst_float grad = dist(&gen);
-    bst_float hess = dist(&gen);
-    gp = GradientPair(grad, hess);
+  for (auto& gp : gpair.HostVector()) {
+    float grad = dist(&gen);
+    float hess = dist(&gen);
+    gp = GradientPair{grad, hess};
   }
-  gpair.SetDevice(DeviceOrd::CUDA(0));
+  gpair.SetDevice(ctx.Device());
 
-  thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector());
-  maker.row_partitioner = std::make_unique<RowPartitioner>(FstCU(), kNRows);
+  thrust::host_vector<common::CompressedByteT> h_gidx_buffer(page->gidx_buffer.HostVector());
+  maker.row_partitioner = std::make_unique<RowPartitioner>(ctx.Device(), kNRows);
 
-  maker.hist.Init(FstCU(), page->Cuts().TotalBins());
+  maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
   maker.hist.AllocateHistograms({0});
 
   maker.gpair = gpair.DeviceSpan();
@@ -121,10 +117,13 @@ void TestBuildHist(bool use_shared_memory_histograms) {
 
   maker.InitFeatureGroupsOnce();
 
-  BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(DeviceOrd::CUDA(0)),
-                         maker.feature_groups->DeviceAccessor(DeviceOrd::CUDA(0)), gpair.DeviceSpan(),
+  DeviceHistogramBuilder builder;
+  builder.Reset(&ctx, maker.feature_groups->DeviceAccessor(ctx.Device()),
+                !use_shared_memory_histograms);
+  builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+                         maker.feature_groups->DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
                          maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0),
-                         *maker.quantiser, !use_shared_memory_histograms);
+                         *maker.quantiser);
 
   DeviceHistogramStorage<>& d_hist = maker.hist;
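
One detail worth spelling out is the shared-memory decision that the HistogramKernel constructor makes, and that these tests drive through the force_global flag: a feature group's histogram fits in shared memory only if max_group_bins * sizeof(GradientPairInt64) stays within the opt-in per-block limit, and the kernel must opt in explicitly before it may be launched with that much dynamic shared memory. dh::MaxSharedMemoryOptin is XGBoost's helper for that limit; a rough standalone equivalent, with a stand-in kernel and gradient-pair type and no error checking, would be:

    #include <cstddef>
    #include <cuda_runtime.h>

    struct GradientPairInt64 { long long grad, hess; };  // stand-in for xgboost's type
    __global__ void HistKernel() {}                      // stand-in for SharedMemHistKernel<...>

    bool UseSharedMemory(int device, std::size_t max_group_bins, bool force_global_memory) {
      int max_opt_in = 0;  // opt-in limit, usually above the default per-block shared memory
      cudaDeviceGetAttribute(&max_opt_in, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

      std::size_t smem_size = sizeof(GradientPairInt64) * max_group_bins;
      bool shared = !force_global_memory && smem_size <= static_cast<std::size_t>(max_opt_in);
      if (shared) {
        // Required before launching with more dynamic shared memory than the default limit.
        cudaFuncSetAttribute(HistKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_opt_in);
      }
      return shared;
    }

When the histogram does not fit, or force_global is set as in half of the test matrix above, the builder launches the global-memory kernel variant instead.
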