From 4230dcb6149a3125a1141f2bf4b2a5513e9ec8b7 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Thu, 18 Mar 2021 13:56:10 -0700 Subject: [PATCH] Re-introduce double buffer in UpdatePosition, to fix perf regression in gpu_hist (#6757) * Revert "gpu_hist performance tweaks (#5707)" This reverts commit f779980f7ea7f6f07e86229b8e78144e8a74e6b3. * Address reviewer's comment * Fix build error --- src/common/device_helpers.cuh | 30 +++++++++++++++++++++++ src/tree/gpu_hist/row_partitioner.cu | 34 ++++++++++++++------------- src/tree/gpu_hist/row_partitioner.cuh | 20 ++++++++++++---- 3 files changed, 63 insertions(+), 21 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index e6024f04f..1da0a3be6 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -549,6 +549,36 @@ class TemporaryArray { size_t size_; }; +/** + * \brief A double buffer, useful for algorithms like sort. + */ +template +class DoubleBuffer { + public: + cub::DoubleBuffer buff; + xgboost::common::Span a, b; + DoubleBuffer() = default; + template + DoubleBuffer(VectorT *v1, VectorT *v2) { + a = xgboost::common::Span(v1->data().get(), v1->size()); + b = xgboost::common::Span(v2->data().get(), v2->size()); + buff = cub::DoubleBuffer(a.data(), b.data()); + } + + size_t Size() const { + CHECK_EQ(a.size(), b.size()); + return a.size(); + } + cub::DoubleBuffer &CubBuffer() { return buff; } + + T *Current() { return buff.Current(); } + xgboost::common::Span CurrentSpan() { + return xgboost::common::Span{buff.Current(), Size()}; + } + + T *Other() { return buff.Alternate(); } +}; + /** * \brief Copies device span to std::vector. * diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu index 935b30676..cf9e3f769 100644 --- a/src/tree/gpu_hist/row_partitioner.cu +++ b/src/tree/gpu_hist/row_partitioner.cu @@ -103,13 +103,17 @@ void Reset(int device_idx, common::Span ridx, } RowPartitioner::RowPartitioner(int device_idx, size_t num_rows) - : device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows) { + : device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows), + ridx_b_(num_rows), position_b_(num_rows) { dh::safe_cuda(cudaSetDevice(device_idx_)); - Reset(device_idx, dh::ToSpan(ridx_a_), dh::ToSpan(position_a_)); + ridx_ = dh::DoubleBuffer{&ridx_a_, &ridx_b_}; + position_ = dh::DoubleBuffer{&position_a_, &position_b_}; + ridx_segments_.emplace_back(Segment(0, num_rows)); + + Reset(device_idx, ridx_.CurrentSpan(), position_.CurrentSpan()); left_counts_.resize(256); thrust::fill(left_counts_.begin(), left_counts_.end(), 0); streams_.resize(2); - ridx_segments_.emplace_back(Segment(0, num_rows)); for (auto& stream : streams_) { dh::safe_cuda(cudaStreamCreate(&stream)); } @@ -129,15 +133,15 @@ common::Span RowPartitioner::GetRows( if (segment.Size() == 0) { return common::Span(); } - return dh::ToSpan(ridx_a_).subspan(segment.begin, segment.Size()); + return ridx_.CurrentSpan().subspan(segment.begin, segment.Size()); } common::Span RowPartitioner::GetRows() { - return dh::ToSpan(ridx_a_); + return ridx_.CurrentSpan(); } common::Span RowPartitioner::GetPosition() { - return dh::ToSpan(position_a_); + return position_.CurrentSpan(); } std::vector RowPartitioner::GetRowsHost( bst_node_t nidx) { @@ -159,25 +163,23 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment, bst_node_t right_nidx, int64_t* d_left_count, cudaStream_t stream) { - dh::TemporaryArray position_temp(position_a_.size()); - dh::TemporaryArray ridx_temp(ridx_a_.size()); SortPosition( // position_in - common::Span(position_a_.data().get() + segment.begin, + common::Span(position_.Current() + segment.begin, segment.Size()), // position_out - common::Span(position_temp.data().get() + segment.begin, + common::Span(position_.Other() + segment.begin, segment.Size()), // row index in - common::Span(ridx_a_.data().get() + segment.begin, segment.Size()), + common::Span(ridx_.Current() + segment.begin, segment.Size()), // row index out - common::Span(ridx_temp.data().get() + segment.begin, segment.Size()), + common::Span(ridx_.Other() + segment.begin, segment.Size()), left_nidx, right_nidx, d_left_count, stream); // Copy back key/value - const auto d_position_current = position_a_.data().get() + segment.begin; - const auto d_position_other = position_temp.data().get() + segment.begin; - const auto d_ridx_current = ridx_a_.data().get() + segment.begin; - const auto d_ridx_other = ridx_temp.data().get() + segment.begin; + const auto d_position_current = position_.Current() + segment.begin; + const auto d_position_other = position_.Other() + segment.begin; + const auto d_ridx_current = ridx_.Current() + segment.begin; + const auto d_ridx_other = ridx_.Other() + segment.begin; dh::LaunchN(device_idx_, segment.Size(), stream, [=] __device__(size_t idx) { d_position_current[idx] = d_position_other[idx]; d_ridx_current[idx] = d_ridx_other[idx]; diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh index c897b4bbf..96f327fb9 100644 --- a/src/tree/gpu_hist/row_partitioner.cuh +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -47,7 +47,17 @@ class RowPartitioner { /*! \brief Range of row index for each node, pointers into ridx below. */ std::vector ridx_segments_; dh::TemporaryArray ridx_a_; + dh::TemporaryArray ridx_b_; dh::TemporaryArray position_a_; + dh::TemporaryArray position_b_; + /*! \brief mapping for node id -> rows. + * This looks like: + * node id | 1 | 2 | + * rows idx | 3, 5, 1 | 13, 31 | + */ + dh::DoubleBuffer ridx_; + /*! \brief mapping for row -> node id. */ + dh::DoubleBuffer position_; dh::caching_device_vector left_counts_; // Useful to keep a bunch of zeroed memory for sort position std::vector streams_; @@ -100,8 +110,8 @@ class RowPartitioner { void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx, bst_node_t right_nidx, UpdatePositionOpT op) { Segment segment = ridx_segments_.at(nidx); // rows belongs to node nidx - auto d_ridx = dh::ToSpan(ridx_a_); - auto d_position = dh::ToSpan(position_a_); + auto d_ridx = ridx_.CurrentSpan(); + auto d_position = position_.CurrentSpan(); if (left_counts_.size() <= nidx) { left_counts_.resize((nidx * 2) + 1); thrust::fill(left_counts_.begin(), left_counts_.end(), 0); @@ -148,9 +158,9 @@ class RowPartitioner { */ template void FinalisePosition(FinalisePositionOpT op) { - auto d_position = position_a_.data().get(); - const auto d_ridx = ridx_a_.data().get(); - dh::LaunchN(device_idx_, position_a_.size(), [=] __device__(size_t idx) { + auto d_position = position_.Current(); + const auto d_ridx = ridx_.Current(); + dh::LaunchN(device_idx_, position_.Size(), [=] __device__(size_t idx) { auto position = d_position[idx]; RowIndexT ridx = d_ridx[idx]; bst_node_t new_position = op(ridx, position);