From a19bbc9be5da139bf770e8a03ede7536f8d193a3 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Tue, 23 Jul 2024 03:48:03 +0800
Subject: [PATCH] Avoid caching allocator for large allocations. (#10582)

---
 src/common/hist_util.cu                       |  2 +-
 src/tree/gpu_hist/row_partitioner.cu          | 16 +++-
 src/tree/gpu_hist/row_partitioner.cuh         | 91 +++++++++++--------
 src/tree/updater_gpu_hist.cu                  |  6 +-
 tests/cpp/tree/gpu_hist/test_histogram.cu     | 14 ++-
 .../cpp/tree/gpu_hist/test_row_partitioner.cu |  3 +-
 tests/cpp/tree/test_gpu_hist.cu               |  3 +-
 7 files changed, 80 insertions(+), 55 deletions(-)

diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index 39f310ebb..3bf4047e2 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -227,7 +227,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
     });
     detail::SortByWeight(&entry_weight, &sorted_entries);
   } else {
-    thrust::sort(cuctx->CTP(), sorted_entries.begin(), sorted_entries.end(),
+    thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(),
                  detail::EntryCompareOp());
   }
diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu
index f66fac489..61e42d909 100644
--- a/src/tree/gpu_hist/row_partitioner.cu
+++ b/src/tree/gpu_hist/row_partitioner.cu
@@ -10,14 +10,20 @@
 #include "row_partitioner.cuh"

 namespace xgboost::tree {
-RowPartitioner::RowPartitioner(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid)
-    : device_idx_(ctx->Device()), ridx_(n_samples), ridx_tmp_(n_samples) {
-  dh::safe_cuda(cudaSetDevice(device_idx_.ordinal));
-  ridx_segments_.emplace_back(NodePositionInfo{Segment(0, n_samples)});
+void RowPartitioner::Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid) {
+  ridx_segments_.clear();
+  ridx_.resize(n_samples);
+  ridx_tmp_.resize(n_samples);
+  tmp_.clear();
+
+  CHECK_LE(n_samples, std::numeric_limits<cuda_impl::RowIndexT>::max());
+  ridx_segments_.emplace_back(
+      NodePositionInfo{Segment{0, static_cast<cuda_impl::RowIndexT>(n_samples)}});
+  thrust::sequence(ctx->CUDACtx()->CTP(), ridx_.data(), ridx_.data() + ridx_.size(), base_rowid);
 }

-RowPartitioner::~RowPartitioner() { dh::safe_cuda(cudaSetDevice(device_idx_.ordinal)); }
+RowPartitioner::~RowPartitioner() = default;

 common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
   auto segment = ridx_segments_.at(nidx).segment;
diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh
index 636de54e6..a811155d4 100644
--- a/src/tree/gpu_hist/row_partitioner.cuh
+++ b/src/tree/gpu_hist/row_partitioner.cuh
@@ -7,25 +7,34 @@
 #include <thrust/iterator/transform_output_iterator.h>  // for make_transform_output_iterator

 #include <algorithm>  // for max
+#include <cstddef>    // for size_t
+#include <cstdint>    // for int32_t, uint32_t
 #include <vector>     // for vector

 #include "../../common/device_helpers.cuh"  // for MakeTransformIterator
 #include "xgboost/base.h"                   // for bst_idx_t
 #include "xgboost/context.h"                // for Context
+#include "xgboost/span.h"                   // for Span

-namespace xgboost {
-namespace tree {
+namespace xgboost::tree {
+namespace cuda_impl {
+using RowIndexT = std::uint32_t;
+}

-/** \brief Used to demarcate a contiguous set of row indices associated with
- * some tree node. */
+/**
+ * @brief Used to demarcate a contiguous set of row indices associated with some tree
+ * node.
+ */
 struct Segment {
-  bst_uint begin{0};
-  bst_uint end{0};
+  cuda_impl::RowIndexT begin{0};
+  cuda_impl::RowIndexT end{0};

   Segment() = default;

-  Segment(bst_uint begin, bst_uint end) : begin(begin), end(end) { CHECK_GE(end, begin); }
-  __host__ __device__ size_t Size() const { return end - begin; }
+  Segment(cuda_impl::RowIndexT begin, cuda_impl::RowIndexT end) : begin(begin), end(end) {
+    CHECK_GE(end, begin);
+  }
+  __host__ __device__ bst_idx_t Size() const { return end - begin; }
 };

 // TODO(Rory): Can be larger. To be tuned alongside other batch operations.
@@ -39,7 +48,7 @@ struct PerNodeData {
 template <typename BatchIterT>
 __device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
                                             int* batch_idx, std::size_t* item_idx) {
-  bst_uint sum = 0;
+  cuda_impl::RowIndexT sum = 0;
   for (int i = 0; i < kMaxUpdatePositionBatchSize; i++) {
     if (sum + batch_info[i].segment.Size() > global_thread_idx) {
       *batch_idx = i;
@@ -65,10 +74,10 @@ __global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
 // We can scan over this tuple, where the scan gives us information on how to partition inputs
 // according to the flag
 struct IndexFlagTuple {
-  bst_uint idx;        // The location of the item we are working on in ridx_
-  bst_uint flag_scan;  // This gets populated after scanning
-  int batch_idx;       // Which node in the batch does this item belong to
-  bool flag;           // Result of op (is this item going left?)
+  cuda_impl::RowIndexT idx;        // The location of the item we are working on in ridx_
+  cuda_impl::RowIndexT flag_scan;  // This gets populated after scanning
+  std::int32_t batch_idx;          // Which node in the batch does this item belong to
+  bool flag;                       // Result of op (is this item going left?)
 };

 struct IndexFlagOp {
@@ -86,18 +95,18 @@
 template <typename OpDataT>
 struct WriteResultsFunctor {
   dh::LDGIterator<PerNodeData<OpDataT>> batch_info;
-  const bst_uint* ridx_in;
-  bst_uint* ridx_out;
-  bst_uint* counts;
+  cuda_impl::RowIndexT const* ridx_in;
+  cuda_impl::RowIndexT* ridx_out;
+  cuda_impl::RowIndexT* counts;

   __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
     std::size_t scatter_address;
     const Segment& segment = batch_info[x.batch_idx].segment;
     if (x.flag) {
-      bst_uint num_previous_flagged = x.flag_scan - 1;  // -1 because inclusive scan
+      cuda_impl::RowIndexT num_previous_flagged = x.flag_scan - 1;  // -1 because inclusive scan
       scatter_address = segment.begin + num_previous_flagged;
     } else {
-      bst_uint num_previous_unflagged = (x.idx - segment.begin) - x.flag_scan;
+      cuda_impl::RowIndexT num_previous_unflagged = (x.idx - segment.begin) - x.flag_scan;
       scatter_address = segment.end - num_previous_unflagged - 1;
     }
     ridx_out[scatter_address] = ridx_in[x.idx];
@@ -115,7 +124,7 @@
 template <typename RowIndexT, typename OpT, typename OpDataT>
 void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
                        common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
-                       common::Span<bst_uint> d_counts, std::size_t total_rows, OpT op,
+                       common::Span<cuda_impl::RowIndexT> d_counts, std::size_t total_rows, OpT op,
                        dh::device_vector<int8_t>* tmp) {
   dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
   WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
@@ -130,7 +139,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
         std::size_t item_idx;
         AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
         auto op_res = op(ridx[item_idx], batch_idx, batch_info_itr[batch_idx].data);
-        return IndexFlagTuple{static_cast<bst_uint>(item_idx), op_res, batch_idx, op_res};
+        return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), op_res, batch_idx, op_res};
       });
   size_t temp_bytes = 0;
   if (tmp->empty()) {
@@ -195,29 +204,31 @@ __global__ __launch_bounds__(kBlockSize) void FinalisePositionKernel(
  * partition training rows into different leaf nodes.
  */
 class RowPartitioner {
  public:
-  using RowIndexT = bst_uint;
+  using RowIndexT = cuda_impl::RowIndexT;
   static constexpr bst_node_t kIgnoredTreePosition = -1;

  private:
-  DeviceOrd device_idx_;
-  /*! \brief In here if you want to find the rows belong to a node nid, first you need to
-   * get the indices segment from ridx_segments[nid], then get the row index that
-   * represents position of row in input data X. `RowPartitioner::GetRows` would be a
-   * good starting place to get a sense what are these vector storing.
+  /**
+   * In here if you want to find the rows belong to a node nid, first you need to get the
+   * indices segment from ridx_segments[nid], then get the row index that represents
+   * position of row in input data X. `RowPartitioner::GetRows` would be a good starting
+   * place to get a sense what are these vector storing.
    *
    * node id -> segment -> indices of rows belonging to node
    */
-  /*! \brief Range of row index for each node, pointers into ridx below. */
+  /** @brief Range of row index for each node, pointers into ridx below. */
   std::vector<NodePositionInfo> ridx_segments_;
-  /*! \brief mapping for node id -> rows.
+  /**
+   * @brief mapping for node id -> rows.
+   *
    * This looks like:
    * node id  |    1    |    2   |
    * rows idx | 3, 5, 1 | 13, 31 |
    */
-  dh::TemporaryArray<RowIndexT> ridx_;
+  dh::DeviceUVector<RowIndexT> ridx_;
   // Staging area for sorting ridx
-  dh::TemporaryArray<RowIndexT> ridx_tmp_;
+  dh::DeviceUVector<RowIndexT> ridx_tmp_;
   dh::device_vector<int8_t> tmp_;
   dh::PinnedMemory pinned_;
   dh::PinnedMemory pinned2_;
@@ -228,7 +239,9 @@ class RowPartitioner {
    * @param n_samples The number of samples in each batch.
    * @param base_rowid The base row index for the current batch.
   */
-  RowPartitioner(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid);
+  RowPartitioner() = default;
+  void Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid);
+
   ~RowPartitioner();
   RowPartitioner(const RowPartitioner&) = delete;
   RowPartitioner& operator=(const RowPartitioner&) = delete;
@@ -285,8 +298,8 @@ class RowPartitioner {
                               cudaMemcpyDefault));

     // Temporary arrays
-    auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);
-    dh::TemporaryArray<bst_uint> d_counts(nidx.size(), 0);
+    auto h_counts = pinned_.GetSpan<RowIndexT>(nidx.size(), 0);
+    dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0);

     // Partition the rows according to the operator
     SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
@@ -299,7 +312,7 @@ class RowPartitioner {
     dh::DefaultStream().Sync();

     // Update segments
-    for (size_t i = 0; i < nidx.size(); i++) {
+    for (std::size_t i = 0; i < nidx.size(); i++) {
       auto segment = ridx_segments_.at(nidx[i]).segment;
       auto left_count = h_counts[i];
       CHECK_LE(left_count, segment.Size());
@@ -336,11 +349,9 @@ class RowPartitioner {
     constexpr int kBlockSize = 512;
     const int kItemsThread = 8;
     const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread);
-    common::Span<RowIndexT> d_ridx(ridx_.data().get(), ridx_.size());
-    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize>>>(
-        dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
+    common::Span<RowIndexT> d_ridx{ridx_.data(), ridx_.size()};
+    FinalisePositionKernel<kBlockSize>
+        <<<grid_size, kBlockSize>>>(dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
   }
 };
-
-};  // namespace tree
-};  // namespace xgboost
+};  // namespace xgboost::tree
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 83f84ec1f..477a7b08a 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -145,9 +145,11 @@ struct GPUHistMakerDevice {
     quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info());

-    row_partitioner.reset();  // Release the device memory first before reallocating
+    if (!row_partitioner) {
+      row_partitioner = std::make_unique<RowPartitioner>();
+    }
+    row_partitioner->Reset(ctx_, sample.sample_rows, page->base_rowid);
     CHECK_EQ(page->base_rowid, 0);
-    row_partitioner = std::make_unique<RowPartitioner>(ctx_, sample.sample_rows, page->base_rowid);

     // Init histogram
     hist.Init(ctx_->Device(), page->Cuts().TotalBins());
diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu
index c9320f616..fe2544215 100644
--- a/tests/cpp/tree/gpu_hist/test_histogram.cu
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -66,7 +66,8 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
   for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) {
     auto* page = batch.Impl();

-    tree::RowPartitioner row_partitioner{&ctx, kRows, page->base_rowid};
+    tree::RowPartitioner row_partitioner;
+    row_partitioner.Reset(&ctx, kRows, page->base_rowid);
     auto ridx = row_partitioner.GetRows(0);

     bst_bin_t num_bins = kBins * kCols;
@@ -171,7 +172,8 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   auto cat_m = GetDMatrixFromData(x, kRows, 1);
   cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
   auto batch_param = BatchParam{kBins, tree::TrainParam::DftSparseThreshold()};
-  tree::RowPartitioner row_partitioner{&ctx, kRows, 0};
+  tree::RowPartitioner row_partitioner;
+  row_partitioner.Reset(&ctx, kRows, 0);
   auto ridx = row_partitioner.GetRows(0);
   dh::device_vector<GradientPairInt64> cat_hist(num_categories);
   auto gpair = GenerateRandomGradients(kRows, 0, 2);
@@ -343,8 +345,8 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam
         fg = std::make_unique<FeatureGroups>(impl->Cuts());
       }

-      partitioners.emplace_back(
-          std::make_unique<RowPartitioner>(&ctx, impl->Size(), impl->base_rowid));
+      partitioners.emplace_back(std::make_unique<RowPartitioner>());
+      partitioners.back()->Reset(&ctx, impl->Size(), impl->base_rowid);
       auto ridx = partitioners.at(k)->GetRows(0);

       auto d_histogram = dh::ToSpan(multi_hist);
@@ -362,7 +364,9 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam
-    RowPartitioner partitioner{&ctx, p_fmat->Info().num_row_, 0};
+    RowPartitioner partitioner;
+    partitioner.Reset(&ctx, p_fmat->Info().num_row_, 0);
+
     SparsePage concat;
     std::vector<float> hess(p_fmat->Info().num_row_, 1.0f);
     for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
index cf0d505d1..f891d73f5 100644
--- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
+++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
@@ -16,7 +16,8 @@ namespace xgboost::tree {
 void TestUpdatePositionBatch() {
   const int kNumRows = 10;
   auto ctx = MakeCUDACtx(0);
-  RowPartitioner rp{&ctx, kNumRows, 0};
+  RowPartitioner rp;
+  rp.Reset(&ctx, kNumRows, 0);
   auto rows = rp.GetRowsHost(0);
   EXPECT_EQ(rows.size(), kNumRows);
   for (auto i = 0ull; i < kNumRows; i++) {
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 728fb62c4..631149261 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -64,7 +64,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   }
   gpair.SetDevice(ctx.Device());

-  maker.row_partitioner = std::make_unique<RowPartitioner>(&ctx, kNRows, 0);
+  maker.row_partitioner = std::make_unique<RowPartitioner>();
+  maker.row_partitioner->Reset(&ctx, kNRows, 0);
   maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
   maker.hist.AllocateHistograms(&ctx, {0});
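
Reviewer note (not part of the patch): the sketch below illustrates the call pattern this change enables. The partitioner is now default-constructed once and re-initialized with Reset() for each batch or tree, so its row-index buffers (now dh::DeviceUVector, which avoids the caching allocator for these large allocations per the patch title) are resized in place instead of being released and re-acquired every iteration. The include paths, the MakeCUDACtx test helper, and the batch sizes below are assumptions taken from the test code above, not part of this patch.

    // Illustrative sketch only: reuse one RowPartitioner across batches via Reset().
    // Paths and helpers are assumed from the tests shown in this patch.
    #include <cstddef>  // for size_t
    #include <vector>   // for vector

    #include "../../helpers.h"                                  // MakeCUDACtx (test helper, assumed path)
    #include "../../../src/tree/gpu_hist/row_partitioner.cuh"   // RowPartitioner (assumed path)

    namespace xgboost::tree {
    void SketchPartitionerReuse() {
      auto ctx = MakeCUDACtx(0);
      RowPartitioner partitioner;  // device buffers stay empty until the first Reset()

      std::size_t base_rowid = 0;
      std::vector<std::size_t> batch_sizes{1024, 2048, 512};  // made-up batch sizes
      for (std::size_t n_samples : batch_sizes) {
        // Resizes the internal ridx_/ridx_tmp_ storage in place; no free + reallocation
        // through the caching allocator between batches.
        partitioner.Reset(&ctx, n_samples, base_rowid);
        auto ridx = partitioner.GetRows(0);  // all rows start in the root segment
        // ... build histograms / call UpdatePositionBatch with ridx ...
        base_rowid += n_samples;
      }
    }
    }  // namespace xgboost::tree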