From 5f910cd4fff898b8fc367dbb722c47467b5e6acd Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Thu, 11 Jul 2024 03:26:30 +0800
Subject: [PATCH] [EM] Handle base idx in GPU histogram. (#10549)

---
 src/tree/gpu_hist/histogram.cu                |  42 +++----
 src/tree/gpu_hist/histogram.cuh               |   1 -
 src/tree/gpu_hist/row_partitioner.cu          |  33 ++---
 src/tree/gpu_hist/row_partitioner.cuh         |  25 ++--
 src/tree/updater_gpu_hist.cu                  |   3 +-
 tests/cpp/tree/gpu_hist/test_histogram.cu     | 115 +++++++++++++++++-
 .../cpp/tree/gpu_hist/test_row_partitioner.cu |  17 ++-
 tests/cpp/tree/test_gpu_hist.cu               |   2 +-
 8 files changed, 167 insertions(+), 71 deletions(-)

diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index cd848c1c0..372a5c09b 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -1,8 +1,7 @@
 /**
  * Copyright 2020-2024, XGBoost Contributors
  */
-#include
-#include
+#include <thrust/iterator/transform_iterator.h>  // for make_transform_iterator

 #include
 #include <cstdint>  // uint32_t, int32_t
@@ -101,9 +100,8 @@ GradientQuantiser::GradientQuantiser(Context const* ctx, common::Span(1) / to_floating_point_.GetHess());
 }
-XGBOOST_DEV_INLINE void
-AtomicAddGpairShared(xgboost::GradientPairInt64 *dest,
-                     xgboost::GradientPairInt64 const &gpair) {
+XGBOOST_DEV_INLINE void AtomicAddGpairShared(xgboost::GradientPairInt64* dest,
+                                             xgboost::GradientPairInt64 const& gpair) {
   auto dst_ptr = reinterpret_cast(dest);
   auto g = gpair.GetQuantisedGrad();
   auto h = gpair.GetQuantisedHess();
@@ -131,7 +129,9 @@ template <int kBlockThreads, int kItemsPerThread>
 class HistogramAgent {
   GradientPairInt64* smem_arr_;
   GradientPairInt64* __restrict__ d_node_hist_;
-  dh::LDGIterator<const RowPartitioner::RowIndexT> d_ridx_;
+  using Idx = RowPartitioner::RowIndexT;
+
+  dh::LDGIterator<const Idx> d_ridx_;
   const GradientPair* d_gpair_;
   const FeatureGroup group_;
   const EllpackDeviceAccessor& matrix_;
@@ -142,8 +142,7 @@ class HistogramAgent {
  public:
   __device__ HistogramAgent(GradientPairInt64* smem_arr,
                             GradientPairInt64* __restrict__ d_node_hist, const FeatureGroup& group,
-                            const EllpackDeviceAccessor& matrix,
-                            common::Span<const RowPartitioner::RowIndexT> d_ridx,
+                            const EllpackDeviceAccessor& matrix, common::Span<const Idx> d_ridx,
                             const GradientQuantiser& rounding, const GradientPair* d_gpair)
       : smem_arr_(smem_arr),
         d_node_hist_(d_node_hist),
@@ -154,15 +153,15 @@ class HistogramAgent {
         n_elements_(feature_stride_ * d_ridx.size()),
         rounding_(rounding),
         d_gpair_(d_gpair) {}
+
   __device__ void ProcessPartialTileShared(std::size_t offset) {
     for (std::size_t idx = offset + threadIdx.x;
          idx < std::min(offset + kBlockThreads * kItemsPerTile, n_elements_); idx += kBlockThreads) {
-      int ridx = d_ridx_[idx / feature_stride_];
-      int gidx =
-          matrix_
-              .gidx_iter[ridx * matrix_.row_stride + group_.start_feature + idx % feature_stride_] -
-          group_.start_bin;
+      Idx ridx = d_ridx_[idx / feature_stride_];
+      Idx midx = (ridx - matrix_.base_rowid) * matrix_.row_stride + group_.start_feature +
+                 idx % feature_stride_;
+      bst_bin_t gidx = matrix_.gidx_iter[midx] - group_.start_bin;
       if (matrix_.is_dense || gidx != matrix_.NumBins()) {
         auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
         AtomicAddGpairShared(smem_arr_ + gidx, adjusted);
@@ -188,8 +187,8 @@ class HistogramAgent {
 #pragma unroll
     for (int i = 0; i < kItemsPerThread; i++) {
       gpair[i] = d_gpair_[ridx[i]];
-      gidx[i] = matrix_.gidx_iter[ridx[i] * matrix_.row_stride + group_.start_feature +
-                                  idx[i] % feature_stride_];
+      gidx[i] = matrix_.gidx_iter[(ridx[i] - matrix_.base_rowid) * matrix_.row_stride +
+                                  group_.start_feature + idx[i] % feature_stride_];
     }
 #pragma unroll
     for (int i = 0; i < kItemsPerThread; i++) {
@@ -200,7 +199,7 @@
     }
   }
   __device__ void BuildHistogramWithShared() {
-    dh::BlockFill(smem_arr_, group_.num_bins, GradientPairInt64());
+    dh::BlockFill(smem_arr_, group_.num_bins, GradientPairInt64{});
     __syncthreads();

     std::size_t offset = blockIdx.x * kItemsPerTile;
@@ -219,10 +218,9 @@
   __device__ void BuildHistogramWithGlobal() {
     for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n_elements_)) {
-      int ridx = d_ridx_[idx / feature_stride_];
-      int gidx =
-          matrix_
-              .gidx_iter[ridx * matrix_.row_stride + group_.start_feature + idx % feature_stride_];
+      Idx ridx = d_ridx_[idx / feature_stride_];
+      bst_bin_t gidx = matrix_.gidx_iter[(ridx - matrix_.base_rowid) * matrix_.row_stride +
+                                         group_.start_feature + idx % feature_stride_];
       if (matrix_.is_dense || gidx != matrix_.NumBins()) {
         auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
         AtomicAddGpairGlobal(d_node_hist_ + gidx, adjusted);
@@ -231,8 +229,7 @@
   }
 };

-template <int kBlockThreads, int kItemsPerThread,
-          bool use_shared_memory_histograms>
+template <int kBlockThreads, int kItemsPerThread, bool use_shared_memory_histograms>
 __global__ void __launch_bounds__(kBlockThreads)
     SharedMemHistKernel(const EllpackDeviceAccessor matrix,
                         const FeatureGroupsAccessor feature_groups,
@@ -251,6 +248,7 @@
     agent.BuildHistogramWithGlobal();
   }
 }
+
 namespace {
 constexpr std::int32_t kBlockThreads = 1024;
 constexpr std::int32_t kItemsPerThread = 8;
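The kernel changes above are the core of this patch: the row partitioner now hands the
histogram kernels global row indices, while under external memory each ELLPACK page only
stores the rows of its own batch, so the page's base_rowid has to be subtracted before
indexing gidx_iter. A minimal standalone sketch of that address computation (PageView and
BinIdx are illustrative names, not XGBoost APIs):

    #include <cstddef>
    #include <cstdint>

    // Assumed layout: an ELLPACK page stores `row_stride` bin entries per row
    // and holds the rows [base_rowid, base_rowid + n_rows) of the full matrix.
    struct PageView {
      const std::uint32_t* gidx_iter;  // batch-local bin-index storage
      std::size_t row_stride;          // entries stored per row
      std::size_t base_rowid;          // first global row held by this page
    };

    // Fetch the bin stored for a *global* row index and a feature position.
    inline std::uint32_t BinIdx(PageView const& page, std::size_t global_ridx,
                                std::size_t feature) {
      // The same `ridx - base_rowid` translation the kernels above now apply.
      std::size_t local_ridx = global_ridx - page.base_rowid;
      return page.gidx_iter[local_ridx * page.row_stride + feature];
    }

Previously the kernels indexed with the global row id directly, which is only correct
when the page starts at row zero, i.e. for a single in-core page.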
diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh
index e30f68208..862821b00 100644
--- a/src/tree/gpu_hist/histogram.cuh
+++ b/src/tree/gpu_hist/histogram.cuh
@@ -78,5 +78,4 @@ class DeviceHistogramBuilder {
                       common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
 };
 }  // namespace xgboost::tree
-
 #endif  // HISTOGRAM_CUH_

diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu
index 35b43d24b..f66fac489 100644
--- a/src/tree/gpu_hist/row_partitioner.cu
+++ b/src/tree/gpu_hist/row_partitioner.cu
@@ -1,28 +1,23 @@
-/*!
- * Copyright 2017-2022 XGBoost contributors
+/**
+ * Copyright 2017-2024, XGBoost contributors
  */
-#include
-#include
-#include
+#include <thrust/sequence.h>  // for sequence

-#include
+#include <vector>  // for vector

-#include "../../common/device_helpers.cuh"
+#include "../../common/cuda_context.cuh"    // for CUDAContext
+#include "../../common/device_helpers.cuh"  // for CopyDeviceSpanToVector, ToSpan
 #include "row_partitioner.cuh"

-namespace xgboost {
-namespace tree {
-
-RowPartitioner::RowPartitioner(DeviceOrd device_idx, size_t num_rows)
-    : device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows) {
+namespace xgboost::tree {
+RowPartitioner::RowPartitioner(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid)
+    : device_idx_(ctx->Device()), ridx_(n_samples), ridx_tmp_(n_samples) {
   dh::safe_cuda(cudaSetDevice(device_idx_.ordinal));
-  ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
-  thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
+  ridx_segments_.emplace_back(NodePositionInfo{Segment(0, n_samples)});
+  thrust::sequence(ctx->CUDACtx()->CTP(), ridx_.data(), ridx_.data() + ridx_.size(), base_rowid);
 }

-RowPartitioner::~RowPartitioner() {
-  dh::safe_cuda(cudaSetDevice(device_idx_.ordinal));
-}
+RowPartitioner::~RowPartitioner() { dh::safe_cuda(cudaSetDevice(device_idx_.ordinal)); }

 common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
   auto segment = ridx_segments_.at(nidx).segment;
@@ -39,6 +34,4 @@ std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(bst_node_t ni
   dh::CopyDeviceSpanToVector(&rows, span);
   return rows;
 }
-
-};  // namespace tree
-};  // namespace xgboost
+};  // namespace xgboost::tree

diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh
index fde6c4dd0..636de54e6 100644
--- a/src/tree/gpu_hist/row_partitioner.cuh
+++ b/src/tree/gpu_hist/row_partitioner.cuh
@@ -1,17 +1,17 @@
-/*!
- * Copyright 2017-2022 XGBoost contributors
+/**
+ * Copyright 2017-2024, XGBoost contributors
  */
 #pragma once
 #include
+#include <thrust/iterator/counting_iterator.h>          // for make_counting_iterator
+#include <thrust/iterator/transform_output_iterator.h>  // for make_transform_output_iterator

-#include
-#include
+#include <algorithm>  // for max
+#include <vector>     // for vector

-#include "../../common/device_helpers.cuh"
-#include "xgboost/base.h"
-#include "xgboost/context.h"
-#include "xgboost/task.h"
-#include "xgboost/tree_model.h"
+#include "../../common/device_helpers.cuh"  // for MakeTransformIterator
+#include "xgboost/base.h"                   // for bst_idx_t
+#include "xgboost/context.h"                // for Context

 namespace xgboost {
 namespace tree {
@@ -223,7 +223,12 @@ class RowPartitioner {
   dh::PinnedMemory pinned2_;

  public:
-  RowPartitioner(DeviceOrd device_idx, size_t num_rows);
+  /**
+   * @param ctx Context for device ordinal and stream.
+   * @param n_samples The number of samples in each batch.
+   * @param base_rowid The base row index for the current batch.
+   */
+  RowPartitioner(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid);
   ~RowPartitioner();
   RowPartitioner(const RowPartitioner&) = delete;
   RowPartitioner& operator=(const RowPartitioner&) = delete;
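The constructor change above is what makes global indices available in the first place:
thrust::sequence is now called with an initial value, so a partitioner created for a batch
fills its ridx buffer with [base_rowid, base_rowid + n_samples) instead of [0, n_samples).
A small self-contained example of that thrust call (the values are made up):

    #include <thrust/device_vector.h>
    #include <thrust/sequence.h>

    #include <cstdint>

    int main() {
      // Pretend this partitioner belongs to the second 512-row batch.
      std::uint32_t base_rowid = 512;
      thrust::device_vector<std::uint32_t> ridx(4);
      // With the extra init argument the sequence starts at base_rowid:
      // ridx becomes {512, 513, 514, 515}, i.e. global row ids.
      thrust::sequence(ridx.begin(), ridx.end(), base_rowid);
      return 0;
    }

The patch also routes the call through ctx->CUDACtx()->CTP(), so the fill runs on the
context's stream instead of the default thrust device policy.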
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index aa4f8fa27..366cf3aad 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -251,7 +251,8 @@ struct GPUHistMakerDevice {
     quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info());

     row_partitioner.reset();  // Release the device memory first before reallocating
-    row_partitioner = std::make_unique<RowPartitioner>(ctx_->Device(), sample.sample_rows);
+    CHECK_EQ(page->base_rowid, 0);
+    row_partitioner = std::make_unique<RowPartitioner>(ctx_, sample.sample_rows, page->base_rowid);

     // Init histogram
     hist.Init(ctx_->Device(), page->Cuts().TotalBins());
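Note the retained single-page assumption here: the in-core updater still builds from
exactly one ELLPACK page, so page->base_rowid is always zero and the new CHECK_EQ
documents that invariant rather than changing behaviour. Under external memory the
intended pattern, which the new test below exercises, is one partitioner per batch, each
anchored at its page's base row. A fragment sketching that loop (it assumes the ctx,
p_fmat, and BatchParam p objects from the test scaffolding below):

    std::vector<std::unique_ptr<RowPartitioner>> partitioners;
    for (auto const& page : p_fmat->GetBatches<EllpackPage>(&ctx, p)) {
      auto impl = page.Impl();
      // Rows [impl->base_rowid, impl->base_rowid + impl->Size()) live in this page.
      partitioners.emplace_back(
          std::make_unique<RowPartitioner>(&ctx, impl->Size(), impl->base_rowid));
    }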
diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu
index 3b9e6103a..d11284466 100644
--- a/tests/cpp/tree/gpu_hist/test_histogram.cu
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -2,13 +2,15 @@
  * Copyright 2020-2024, XGBoost Contributors
  */
 #include <gtest/gtest.h>
+#include <xgboost/context.h>  // for Context

-#include
+#include <memory>  // for unique_ptr
+#include <vector>  // for vector

 #include "../../../../src/tree/gpu_hist/histogram.cuh"
-#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
-#include "../../../../src/tree/param.h"  // TrainParam
-#include "../../categorical_helpers.h"
+#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"  // for RowPartitioner
+#include "../../../../src/tree/param.h"                       // for TrainParam
+#include "../../categorical_helpers.h"                        // for OneHotEncodeFeature
 #include "../../helpers.h"
@@ -24,7 +26,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
   for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) {
     auto* page = batch.Impl();

-    tree::RowPartitioner row_partitioner(ctx.Device(), kRows);
+    tree::RowPartitioner row_partitioner{&ctx, kRows, page->base_rowid};
     auto ridx = row_partitioner.GetRows(0);

     bst_bin_t num_bins = kBins * kCols;
@@ -129,7 +131,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   auto cat_m = GetDMatrixFromData(x, kRows, 1);
   cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
   auto batch_param = BatchParam{kBins, tree::TrainParam::DftSparseThreshold()};
-  tree::RowPartitioner row_partitioner(ctx.Device(), kRows);
+  tree::RowPartitioner row_partitioner{&ctx, kRows, 0};
   auto ridx = row_partitioner.GetRows(0);
   dh::device_vector<GradientPairInt64> cat_hist(num_categories);
   auto gpair = GenerateRandomGradients(kRows, 0, 2);
@@ -262,4 +264,105 @@ TEST(Histogram, Quantiser) {
     ASSERT_EQ(gh.GetHess(), 1.0);
   }
 }
+namespace {
+class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<float, bool>> {
+ public:
+  void Run(float sparsity, bool force_global) {
+    bst_idx_t n_samples{512}, n_features{12}, n_batches{3};
+    std::vector<std::unique_ptr<RowPartitioner>> partitioners;
+    auto p_fmat = RandomDataGenerator{n_samples, n_features, sparsity}
+                      .Batches(n_batches)
+                      .GenerateSparsePageDMatrix("cache", true);
+    bst_bin_t n_bins = 16;
+    BatchParam p{n_bins, TrainParam::DftSparseThreshold()};
+    auto ctx = MakeCUDACtx(0);
+
+    std::unique_ptr<FeatureGroups> fg;
+    dh::device_vector<GradientPairInt64> single_hist;
+    dh::device_vector<GradientPairInt64> multi_hist;
+
+    auto gpair = GenerateRandomGradients(n_samples);
+    gpair.SetDevice(ctx.Device());
+    auto quantiser = GradientQuantiser{&ctx, gpair.ConstDeviceSpan(), p_fmat->Info()};
+    std::shared_ptr<common::HistogramCuts> cuts;
+
+    {
+      /**
+       * Multi page.
+       */
+      std::int32_t k{0};
+      for (auto const& page : p_fmat->GetBatches<EllpackPage>(&ctx, p)) {
+        auto impl = page.Impl();
+        if (k == 0) {
+          // Initialization
+          auto d_matrix = impl->GetDeviceAccessor(ctx.Device());
+          fg = std::make_unique<FeatureGroups>(impl->Cuts());
+          auto init = GradientPairInt64{0, 0};
+          multi_hist = decltype(multi_hist)(impl->Cuts().TotalBins(), init);
+          single_hist = decltype(single_hist)(impl->Cuts().TotalBins(), init);
+          cuts = std::make_shared<common::HistogramCuts>(impl->Cuts());
+        }
+
+        partitioners.emplace_back(
+            std::make_unique<RowPartitioner>(&ctx, impl->Size(), impl->base_rowid));
+
+        auto ridx = partitioners.at(k)->GetRows(0);
+        auto d_histogram = dh::ToSpan(multi_hist);
+        DeviceHistogramBuilder builder;
+        builder.Reset(&ctx, fg->DeviceAccessor(ctx.Device()), force_global);
+        builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(ctx.Device()),
+                               fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
+                               d_histogram, quantiser);
+        ++k;
+      }
+      ASSERT_EQ(k, n_batches);
+    }
+
+    {
+      /**
+       * Single page.
+       */
+      RowPartitioner partitioner{&ctx, p_fmat->Info().num_row_, 0};
+      SparsePage concat;
+      std::vector<float> hess(p_fmat->Info().num_row_, 1.0f);
+      for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
+        concat.Push(page);
+      }
+      EllpackPageImpl page{
+          ctx.Device(), cuts, concat, p_fmat->IsDense(), p_fmat->Info().num_col_, {}};
+      auto ridx = partitioner.GetRows(0);
+      auto d_histogram = dh::ToSpan(single_hist);
+      DeviceHistogramBuilder builder;
+      builder.Reset(&ctx, fg->DeviceAccessor(ctx.Device()), force_global);
+      builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(ctx.Device()),
+                             fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
+                             d_histogram, quantiser);
+    }
+
+    std::vector<GradientPairInt64> h_single(single_hist.size());
+    thrust::copy(single_hist.begin(), single_hist.end(), h_single.begin());
+    std::vector<GradientPairInt64> h_multi(multi_hist.size());
+    thrust::copy(multi_hist.begin(), multi_hist.end(), h_multi.begin());
+
+    for (std::size_t i = 0; i < single_hist.size(); ++i) {
+      ASSERT_EQ(h_single[i].GetQuantisedGrad(), h_multi[i].GetQuantisedGrad());
+      ASSERT_EQ(h_single[i].GetQuantisedHess(), h_multi[i].GetQuantisedHess());
+    }
+  }
+};
+}  // namespace
+
+TEST_P(HistogramExternalMemoryTest, ExternalMemory) {
+  std::apply(&HistogramExternalMemoryTest::Run, std::tuple_cat(std::make_tuple(this), GetParam()));
+}
+
+INSTANTIATE_TEST_SUITE_P(Histogram, HistogramExternalMemoryTest, ::testing::ValuesIn([]() {
+                           std::vector<std::tuple<float, bool>> params;
+                           for (auto global : {true, false}) {
+                             for (auto sparsity : {0.0f, 0.2f, 0.8f}) {
+                               params.emplace_back(sparsity, global);
+                             }
+                           }
+                           return params;
+                         }()));
 }  // namespace xgboost::tree

diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
index 14ea6fd70..cf0d505d1 100644
--- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
+++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
@@ -1,25 +1,22 @@
-/*!
- * Copyright 2019-2022 by XGBoost Contributors
+/**
+ * Copyright 2019-2024, XGBoost Contributors
  */
 #include <gtest/gtest.h>
 #include

-#include
-#include
-#include
-#include
+#include <cstddef>  // for size_t
+#include <cstdint>  // for uint32_t
+#include <vector>   // for vector

 #include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
 #include "../../helpers.h"
 #include "xgboost/base.h"
-#include "xgboost/context.h"
-#include "xgboost/task.h"
-#include "xgboost/tree_model.h"

 namespace xgboost::tree {
 void TestUpdatePositionBatch() {
   const int kNumRows = 10;
-  RowPartitioner rp(FstCU(), kNumRows);
+  auto ctx = MakeCUDACtx(0);
+  RowPartitioner rp{&ctx, kNumRows, 0};
   auto rows = rp.GetRowsHost(0);
   EXPECT_EQ(rows.size(), kNumRows);
   for (auto i = 0ull; i < kNumRows; i++) {

diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 1c156563c..200fb39fb 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -106,7 +106,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   gpair.SetDevice(ctx.Device());

   thrust::host_vector h_gidx_buffer(page->gidx_buffer.HostVector());
-  maker.row_partitioner = std::make_unique<RowPartitioner>(ctx.Device(), kNRows);
+  maker.row_partitioner = std::make_unique<RowPartitioner>(&ctx, kNRows, 0);

   maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
   maker.hist.AllocateHistograms({0});
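The equality assertions at the end of the new external-memory test pin down the right
invariant: a quantised histogram is a per-bin integer sum over rows, so accumulating page
by page into one buffer must reproduce the single concatenated page bit for bit,

    multi_hist[b] = sum over pages P of sum over rows i in P with bin(i) = b of q(g_i)
                  = sum over all rows i with bin(i) = b of q(g_i)
                  = single_hist[b]

and because the summands are fixed-point integers, the order of accumulation cannot
change the result, unlike a floating-point histogram.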