From 7af0946ac18020b756fd22015563b23eb487f7a7 Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Wed, 14 Nov 2018 19:33:29 +1300
Subject: [PATCH] Improve update position function for gpu_hist (#3895)

---
 src/common/device_helpers.cuh   |  70 ++++++++++++++++-
 src/tree/updater_gpu_hist.cu    | 130 +++++++++++++++++++-------------
 tests/cpp/tree/test_gpu_hist.cu |  41 +++++++++-
 3 files changed, 184 insertions(+), 57 deletions(-)

diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 550aed407..d0556a30c 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -766,7 +766,8 @@ typename std::iterator_traits<T>::value_type SumReduction(
     dh::CubMemory &tmp_mem, T in, int nVals) {
   using ValueT = typename std::iterator_traits<T>::value_type;
   size_t tmpSize;
-  dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals));
+  ValueT *dummy_out = nullptr;
+  dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, dummy_out, nVals));
   // Allocate small extra memory for the return value
   tmp_mem.LazyAllocate(tmpSize + sizeof(ValueT));
   auto ptr = reinterpret_cast<ValueT *>(tmp_mem.d_temp_storage) + 1;
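The hunk above fixes CUB's temporary-storage size query in SumReduction: passing the input iterator `in` as the output argument would not compile once callers supply a read-only transform iterator (as UpdatePosition does later in this patch), so a typed null output pointer is used instead. For readers unfamiliar with the idiom, a minimal standalone sketch of the two-phase CUB call (`SumOnDevice` is an illustrative name, not part of the patch):

    #include <cub/cub.cuh>
    #include <thrust/device_vector.h>

    float SumOnDevice(const thrust::device_vector<float>& in) {
      thrust::device_vector<float> out(1);
      size_t temp_bytes = 0;
      // Phase 1: with null temp storage, CUB only reports the bytes it needs.
      cub::DeviceReduce::Sum(nullptr, temp_bytes, in.data().get(),
                             out.data().get(), static_cast<int>(in.size()));
      thrust::device_vector<char> temp(temp_bytes);
      // Phase 2: identical arguments, now with real scratch space.
      cub::DeviceReduce::Sum(temp.data().get(), temp_bytes, in.data().get(),
                             out.data().get(), static_cast<int>(in.size()));
      return out[0];  // implicit device-to-host copy of the result
    }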
@@ -1074,4 +1075,71 @@ xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
   using IndexT = typename xgboost::common::Span<T>::index_type;
   return ToSpan(vec, static_cast<IndexT>(offset), static_cast<IndexT>(size));
 }
+
+template <typename FunctionT>
+class LauncherItr {
+ public:
+  int idx;
+  FunctionT f;
+  XGBOOST_DEVICE LauncherItr() : idx(0) {}
+  XGBOOST_DEVICE LauncherItr(int idx, FunctionT f) : idx(idx), f(f) {}
+  XGBOOST_DEVICE LauncherItr &operator=(int output) {
+    f(idx, output);
+    return *this;
+  }
+};
+
+/**
+ * \brief Thrust compatible iterator type - discards algorithm output and
+ * launches device lambda with the index of the output and the algorithm
+ * output as arguments.
+ *
+ * \author Rory
+ * \date 7/9/2017
+ *
+ * \tparam FunctionT Type of the device lambda.
+ */
+template <typename FunctionT>
+class DiscardLambdaItr {
+ public:
+  // Required iterator traits
+  using self_type = DiscardLambdaItr;        // NOLINT
+  using difference_type = ptrdiff_t;         // NOLINT
+  using value_type = void;                   // NOLINT
+  using pointer = value_type *;              // NOLINT
+  using reference = LauncherItr<FunctionT>;  // NOLINT
+  using iterator_category = typename thrust::detail::iterator_facade_category<
+      thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
+      reference>::type;  // NOLINT
+
+ private:
+  difference_type offset_;
+  FunctionT f_;
+
+ public:
+  XGBOOST_DEVICE explicit DiscardLambdaItr(FunctionT f) : offset_(0), f_(f) {}
+  XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, FunctionT f)
+      : offset_(offset), f_(f) {}
+  XGBOOST_DEVICE self_type operator+(const int &b) const {
+    return DiscardLambdaItr(offset_ + b, f_);
+  }
+  XGBOOST_DEVICE self_type operator++() {
+    offset_++;
+    return *this;
+  }
+  XGBOOST_DEVICE self_type operator++(int) {
+    self_type retval = *this;
+    offset_++;
+    return retval;
+  }
+  XGBOOST_DEVICE self_type &operator+=(const int &b) {
+    offset_ += b;
+    return *this;
+  }
+  XGBOOST_DEVICE reference operator*() const {
+    return LauncherItr<FunctionT>(offset_, f_);
+  }
+  XGBOOST_DEVICE reference operator[](int idx) {
+    self_type offset = (*this) + idx;
+    return *offset;
+  }
+};
+
 }  // namespace dh
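Together, LauncherItr and DiscardLambdaItr let a CUB or Thrust algorithm hand each output element to a device lambda instead of writing it to memory: assigning through the dereferenced iterator invokes `f(idx, output)`. A hypothetical usage sketch (`d_in`, `d_out` and `n` are illustrative, and extended `__device__` lambdas require nvcc's `--expt-extended-lambda`):

    // Consume each exclusive-scan output in place of an output buffer.
    auto absorb = [=] __device__(int idx, int scan_value) {
      d_out[idx] = scan_value;  // any per-element side effect goes here
    };
    dh::DiscardLambdaItr<decltype(absorb)> out_itr(absorb);
    size_t temp_bytes = 0;
    cub::DeviceScan::ExclusiveSum(nullptr, temp_bytes, d_in, out_itr, n);
    // ...allocate temp_bytes of scratch, then repeat the call to run the scan.

SortPosition in the next file relies on exactly this mechanism to scatter scan results into two segments.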
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 4177ee96e..99d795063 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -380,6 +380,53 @@ struct Segment {
   size_t Size() const { return end - begin; }
 };
 
+/** \brief Returns one if the left node index is encountered, otherwise
+ * returns zero. */
+struct IndicateLeftTransform {
+  int left_nidx;
+  explicit IndicateLeftTransform(int left_nidx) : left_nidx(left_nidx) {}
+  __host__ __device__ __forceinline__ int operator()(const int& x) const {
+    return x == left_nidx ? 1 : 0;
+  }
+};
+
+/**
+ * \brief Optimised routine for sorting key-value pairs into left and right
+ * segments. Based on a single pass of exclusive scan, uses iterators to
+ * redirect inputs and outputs.
+ */
+void SortPosition(dh::CubMemory* temp_memory, common::Span<int> position,
+                  common::Span<int> position_out, common::Span<bst_uint> ridx,
+                  common::Span<bst_uint> ridx_out, int left_nidx,
+                  int right_nidx, int64_t left_count) {
+  auto d_position_out = position_out.data();
+  auto d_position_in = position.data();
+  auto d_ridx_out = ridx_out.data();
+  auto d_ridx_in = ridx.data();
+  auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
+    int scatter_address;
+    if (d_position_in[idx] == left_nidx) {
+      scatter_address = ex_scan_result;
+    } else {
+      scatter_address = (idx - ex_scan_result) + left_count;
+    }
+    d_position_out[scatter_address] = d_position_in[idx];
+    d_ridx_out[scatter_address] = d_ridx_in[idx];
+  };  // NOLINT
+
+  IndicateLeftTransform conversion_op(left_nidx);
+  cub::TransformInputIterator<int, IndicateLeftTransform, int*> in_itr(
+      d_position_in, conversion_op);
+  dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
+  size_t temp_storage_bytes = 0;
+  cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
+                                position.size());
+  temp_memory->LazyAllocate(temp_storage_bytes);
+  cub::DeviceScan::ExclusiveSum(temp_memory->d_temp_storage,
+                                temp_memory->temp_storage_bytes, in_itr,
+                                out_itr, position.size());
+}
+
 struct DeviceShard;
 
 struct GPUHistBuilderBase {
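The scatter rule inside `write_results` is easier to verify on the host: a left row lands at its exclusive-scan rank among left rows, and a right row at `(idx - rank) + left_count`, so both sides keep their relative order. A host-side mock of the same addressing (a sketch; `StablePartition` is a hypothetical name, not in the patch):

    #include <cstddef>
    #include <vector>

    std::vector<int> StablePartition(const std::vector<int>& pos, int left_nidx,
                                     int left_count) {
      std::vector<int> out(pos.size());
      int ex_scan = 0;  // running exclusive scan of the left-flag
      for (std::size_t idx = 0; idx < pos.size(); ++idx) {
        int addr = (pos[idx] == left_nidx)
                       ? ex_scan
                       : static_cast<int>(idx) - ex_scan + left_count;
        out[addr] = pos[idx];
        ex_scan += (pos[idx] == left_nidx);
      }
      return out;
    }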
@@ -440,26 +487,22 @@ struct DeviceShard {
   TrainParam param;
   bool prediction_cache_initialised;
 
-  int64_t* tmp_pinned;  // Small amount of staging memory
-
   dh::CubMemory temp_memory;
 
   std::unique_ptr<GPUHistBuilderBase> hist_builder;
 
   // TODO(canonizer): do add support multi-batch DMatrix here
-  DeviceShard(int device_id,
-              bst_uint row_begin, bst_uint row_end, TrainParam _param) :
-    device_id_(device_id),
-    row_begin_idx(row_begin),
-    row_end_idx(row_end),
-    row_stride(0),
-    n_rows(row_end - row_begin),
-    n_bins(0),
-    null_gidx_value(0),
-    param(_param),
-    prediction_cache_initialised(false),
-    tmp_pinned(nullptr)
-  {}
+  DeviceShard(int device_id, bst_uint row_begin, bst_uint row_end,
+              TrainParam _param)
+      : device_id_(device_id),
+        row_begin_idx(row_begin),
+        row_end_idx(row_end),
+        row_stride(0),
+        n_rows(row_end - row_begin),
+        n_bins(0),
+        null_gidx_value(0),
+        param(_param),
+        prediction_cache_initialised(false) {}
 
   /* Init row_ptrs and row_stride */
   void InitRowPtrs(const SparsePage& row_batch) {
@@ -495,7 +538,6 @@ struct DeviceShard {
   void CreateHistIndices(const SparsePage& row_batch);
 
   ~DeviceShard() {
-    dh::safe_cuda(cudaFreeHost(tmp_pinned));
   }
 
   // Reset values for each update iteration
@@ -587,29 +629,18 @@ struct DeviceShard {
     hist.HistogramExists(nidx_parent);
   }
 
-  /*! \brief Count how many rows are assigned to left node. */
-  __device__ void CountLeft(int64_t* d_count, int val, int left_nidx) {
-    unsigned ballot = __ballot(val == left_nidx);
-    if (threadIdx.x % 32 == 0) {
-      atomicAdd(reinterpret_cast<unsigned long long*>(d_count),    // NOLINT
-                static_cast<unsigned long long>(__popc(ballot)));  // NOLINT
-    }
-  }
-
   void UpdatePosition(int nidx, int left_nidx, int right_nidx, int fidx,
                       int64_t split_gidx, bool default_dir_left, bool is_dense,
                       int fidx_begin,  // cut.row_ptr[fidx]
                       int fidx_end) {  // cut.row_ptr[fidx + 1]
     dh::safe_cuda(cudaSetDevice(device_id_));
-    auto d_left_count = temp_memory.GetSpan<int64_t>(1);
-    dh::safe_cuda(cudaMemset(d_left_count.data(), 0, sizeof(int64_t)));
     Segment segment = ridx_segments[nidx];
     bst_uint* d_ridx = ridx.Current();
     int* d_position = position.Current();
     common::CompressedIterator<uint32_t> d_gidx = gidx;
     size_t row_stride = this->row_stride;
     // Launch 1 thread for each row
-    dh::LaunchN<1, 512>(
+    dh::LaunchN<1, 128>(
         device_id_, segment.Size(), [=] __device__(bst_uint idx) {
           idx += segment.begin;
           bst_uint ridx = d_ridx[idx];
@@ -634,13 +665,16 @@ struct DeviceShard {
             position = default_dir_left ? left_nidx : right_nidx;
           }
 
-          CountLeft(d_left_count.data(), position, left_nidx);
           d_position[idx] = position;
         });
-    dh::safe_cuda(cudaMemcpy(tmp_pinned, d_left_count.data(), sizeof(int64_t),
-                             cudaMemcpyDeviceToHost));
-    auto left_count = *tmp_pinned;
-    SortPosition(segment, left_nidx, right_nidx);
+    IndicateLeftTransform conversion_op(left_nidx);
+    cub::TransformInputIterator<int, IndicateLeftTransform, int*> left_itr(
+        d_position + segment.begin, conversion_op);
+    int left_count = dh::SumReduction(temp_memory, left_itr, segment.Size());
+    CHECK_LE(left_count, segment.Size());
+    CHECK_GE(left_count, 0);
+
+    SortPositionAndCopy(segment, left_nidx, right_nidx, left_count);
     ridx_segments[left_nidx] =
         Segment(segment.begin, segment.begin + left_count);
@@ -649,25 +683,15 @@ struct DeviceShard {
     ridx_segments[right_nidx] =
         Segment(segment.begin + left_count, segment.end);
   }
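With the warp-ballot CountLeft kernel gone, `left_count` comes from summing the 0/1 flags that IndicateLeftTransform produces over a `cub::TransformInputIterator`; this also removes the pinned-memory staging copy, since `dh::SumReduction` already returns the value to the host. Semantically it is just an element count, which a host-callable sketch can express directly with Thrust (`CountEqual` is an illustrative name):

    #include <thrust/count.h>
    #include <thrust/device_vector.h>

    // thrust::count performs the same flag-and-reduce on the device.
    int CountEqual(const thrust::device_vector<int>& pos, int value) {
      return static_cast<int>(thrust::count(pos.begin(), pos.end(), value));
    }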
 
   /*! \brief Sort row indices according to position. */
-  void SortPosition(const Segment& segment, int left_nidx, int right_nidx) {
-    int min_bits = 0;
-    int max_bits = static_cast<int>(
-        std::ceil(std::log2((std::max)(left_nidx, right_nidx) + 1)));
-
-    size_t temp_storage_bytes = 0;
-    cub::DeviceRadixSort::SortPairs(
-        nullptr, temp_storage_bytes,
-        position.Current() + segment.begin, position.other() + segment.begin,
-        ridx.Current() + segment.begin, ridx.other() + segment.begin,
-        segment.Size(), min_bits, max_bits);
-
-    temp_memory.LazyAllocate(temp_storage_bytes);
-
-    cub::DeviceRadixSort::SortPairs(
-        temp_memory.d_temp_storage, temp_memory.temp_storage_bytes,
-        position.Current() + segment.begin, position.other() + segment.begin,
-        ridx.Current() + segment.begin, ridx.other() + segment.begin,
-        segment.Size(), min_bits, max_bits);
+  void SortPositionAndCopy(const Segment& segment, int left_nidx, int right_nidx,
+                           size_t left_count) {
+    SortPosition(
+        &temp_memory,
+        common::Span<int>(position.Current() + segment.begin, segment.Size()),
+        common::Span<int>(position.other() + segment.begin, segment.Size()),
+        common::Span<bst_uint>(ridx.Current() + segment.begin, segment.Size()),
+        common::Span<bst_uint>(ridx.other() + segment.begin, segment.Size()),
+        left_nidx, right_nidx, left_count);
     // Copy back key
     dh::safe_cuda(cudaMemcpy(
         position.Current() + segment.begin, position.other() + segment.begin,
@@ -823,8 +847,6 @@ inline void DeviceShard::InitCompressedData(
 
   // Init histogram
   hist.Init(device_id_, hmat.row_ptr.back());
-
-  dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t)));
 }
 
 inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) {
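SortPositionAndCopy above scatters into the inactive halves of the `position` and `ridx` double buffers and then copies the results back, so `Current()` remains the authoritative copy for the rest of DeviceShard. A sketch of the underlying ping-pong pattern with raw CUB (an assumption for illustration: `dh::DoubleBuffer` behaves like `cub::DoubleBuffer`; `d_a`, `d_b` and `n` are illustrative):

    cub::DoubleBuffer<int> buf(d_a, d_b);  // two device allocations + selector
    // ...scatter from buf.Current() into buf.Alternate()...
    // Copy back so callers can keep reading buf.Current():
    dh::safe_cuda(cudaMemcpy(buf.Current(), buf.Alternate(), n * sizeof(int),
                             cudaMemcpyDeviceToDevice));

Flipping the selector instead of copying would save the memcpy, at the cost of every consumer having to track which half is live.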
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 600e83243..947e8b11c 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -327,8 +327,6 @@ TEST(GpuHist, ApplySplit) {
   shard->row_stride = n_cols;
   thrust::sequence(shard->ridx.CurrentDVec().tbegin(),
                    shard->ridx.CurrentDVec().tend());
-  // Free inside DeviceShard
-  dh::safe_cuda(cudaMallocHost(&(shard->tmp_pinned), sizeof(int64_t)));
   // Initialize GPUHistMaker
   hist_maker.param_ = param;
   RegTree tree;
@@ -389,5 +387,44 @@ TEST(GpuHist, ApplySplit) {
   ASSERT_EQ(shard->ridx_segments[right_nidx].end, 16);
 }
 
+void TestSortPosition(const std::vector<int>& position_in, int left_idx,
+                      int right_idx) {
+  int left_count = std::count(position_in.begin(), position_in.end(), left_idx);
+  thrust::device_vector<int> position = position_in;
+  thrust::device_vector<int> position_out(position.size());
+
+  thrust::device_vector<bst_uint> ridx(position.size());
+  thrust::sequence(ridx.begin(), ridx.end());
+  thrust::device_vector<bst_uint> ridx_out(ridx.size());
+  dh::CubMemory tmp;
+  SortPosition(
+      &tmp, common::Span<int>(position.data().get(), position.size()),
+      common::Span<int>(position_out.data().get(), position_out.size()),
+      common::Span<bst_uint>(ridx.data().get(), ridx.size()),
+      common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
+      right_idx, left_count);
+  thrust::host_vector<int> position_result = position_out;
+  thrust::host_vector<bst_uint> ridx_result = ridx_out;
+
+  // Check position is sorted
+  EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
+  // Check row indices are sorted inside left and right segment
+  EXPECT_TRUE(
+      std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count));
+  EXPECT_TRUE(
+      std::is_sorted(ridx_result.begin() + left_count, ridx_result.end()));
+
+  // Check key value pairs are the same
+  for (auto i = 0ull; i < ridx_result.size(); i++) {
+    EXPECT_EQ(position_result[i], position_in[ridx_result[i]]);
+  }
+}
+
+TEST(GpuHist, SortPosition) {
+  TestSortPosition({1, 2, 1, 2, 1}, 1, 2);
+  TestSortPosition({1, 1, 1, 1}, 1, 2);
+  TestSortPosition({2, 2, 2, 2}, 1, 2);
+  TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
+}
 }  // namespace tree
 }  // namespace xgboost
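For a concrete picture of what these tests assert, a worked trace of the first case, derived by hand from SortPosition's scatter rule:

    // TestSortPosition({1, 2, 1, 2, 1}, 1, 2):
    //   position_in   = {1, 2, 1, 2, 1}, left_idx = 1, left_count = 3
    //   left flags    = {1, 0, 1, 0, 1}
    //   exclusive sum = {0, 1, 1, 2, 2}
    //   scatter: idx 0 -> 0, idx 2 -> 1, idx 4 -> 2            (left side)
    //            idx 1 -> (1 - 1) + 3 = 3, idx 3 -> (3 - 2) + 3 = 4  (right)
    //   position_out  = {1, 1, 1, 2, 2}   (sorted)
    //   ridx_out      = {0, 2, 4, 1, 3}   (stable within each segment)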