From e268fb0093861679755413131058d16530654941 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 16 Apr 2020 21:41:16 +1200 Subject: [PATCH] Use thrust functions instead of custom functions (#5544) --- src/common/device_helpers.cuh | 161 +----------------------- src/data/device_adapter.cuh | 2 +- src/data/ellpack_page.cu | 4 +- src/objective/rank_obj.cu | 17 ++- src/tree/gpu_hist/row_partitioner.cu | 108 +++++++++------- tests/cpp/common/test_device_helpers.cu | 96 -------------- 6 files changed, 82 insertions(+), 306 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index e709a8133..54cf920a7 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -9,15 +9,14 @@ #include #include #include +#include -#include #include #include #include #include #include -#include #include #include #include @@ -28,7 +27,6 @@ #include "xgboost/span.h" #include "common.h" -#include "timer.h" #ifdef XGBOOST_USE_NCCL #include "nccl.h" @@ -132,94 +130,6 @@ DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, un static_cast(b) << (ibyte % (sizeof(unsigned int)) * 8)); } -namespace internal { - -// Items of size 'n' are sorted in an order determined by the Comparator -// If left is true, find the number of elements where 'comp(item, v)' returns true; -// 0 if nothing is true -// If left is false, find the number of elements where '!comp(item, v)' returns true; -// 0 if nothing is true -template > -XGBOOST_DEVICE __forceinline__ uint32_t -CountNumItemsImpl(bool left, const T * __restrict__ items, uint32_t n, T v, - const Comparator &comp = Comparator()) { - const T *items_begin = items; - uint32_t num_remaining = n; - const T *middle_item = nullptr; - uint32_t middle; - while (num_remaining > 0) { - middle_item = items_begin; - middle = num_remaining / 2; - middle_item += middle; - if ((left && comp(*middle_item, v)) || (!left && !comp(v, *middle_item))) { - items_begin = ++middle_item; - num_remaining -= middle + 1; - } else { - num_remaining = middle; - } - } - - return left ? items_begin - items : items + n - items_begin; -} - -} // namespace internal - -/*! - * \brief Find the strict upper bound for an element in a sorted array - * using binary search. - * \param items pointer to the first element of the sorted array - * \param n length of the sorted array - * \param v value for which to find the upper bound - * \param comp determines how the items are sorted ascending/descending order - should conform - * to ordering semantics - * \return the smallest index i that has a value > v, or n if none is larger when sorted ascendingly - * or, an index i with a value < v, or 0 if none is smaller when sorted descendingly -*/ -// Preserve existing default behavior of upper bound -template > -XGBOOST_DEVICE __forceinline__ uint32_t UpperBound(const T *__restrict__ items, - uint32_t n, - T v, - const Comp &comp = Comp()) { - if (std::is_same>::value || - std::is_same>::value) { - return n - internal::CountNumItemsImpl(false, items, n, v, comp); - } else { - static_assert(std::is_same>::value || - std::is_same>::value, - "Invalid comparator used in Upperbound - can only be thrust::greater/less"); - return std::numeric_limits::max(); // Simply to quiesce the compiler - } -} - -/*! - * \brief Find the strict lower bound for an element in a sorted array - * using binary search. - * \param items pointer to the first element of the sorted array - * \param n length of the sorted array - * \param v value for which to find the upper bound - * \param comp determines how the items are sorted ascending/descending order - should conform - * to ordering semantics - * \return the smallest index i that has a value >= v, or n if none is larger - * when sorted ascendingly - * or, an index i with a value <= v, or 0 if none is smaller when sorted descendingly -*/ -template > -XGBOOST_DEVICE __forceinline__ uint32_t LowerBound(const T *__restrict__ items, - uint32_t n, - T v, - const Comp &comp = Comp()) { - if (std::is_same>::value || - std::is_same>::value) { - return internal::CountNumItemsImpl(true, items, n, v, comp); - } else { - static_assert(std::is_same>::value || - std::is_same>::value, - "Invalid comparator used in LowerBound - can only be thrust::greater/less"); - return std::numeric_limits::max(); // Simply to quiesce the compiler - } -} - template __device__ xgboost::common::Range GridStrideRange(T begin, T end) { begin += blockDim.x * blockIdx.x + threadIdx.x; @@ -878,7 +788,8 @@ class SegmentSorter { const uint32_t *dgroups = dgroups_.data().get(); uint32_t ngroups = dgroups_.size(); auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) { - return dh::UpperBound(dgroups, ngroups, idx) - 1; + return thrust::upper_bound(thrust::seq, dgroups, dgroups + ngroups, idx) - + dgroups - 1; }; // NOLINT thrust::transform(thrust::make_counting_iterator(static_cast(0)), @@ -1018,70 +929,4 @@ thrust::transform_iterator MakeTransformIterator( return thrust::transform_iterator(iter, func); } -template -class LauncherItr { -public: - int idx { 0 }; - FunctionT f; - XGBOOST_DEVICE LauncherItr() : idx(0) {} // NOLINT - XGBOOST_DEVICE LauncherItr(int idx, FunctionT f) : idx(idx), f(f) {} - XGBOOST_DEVICE LauncherItr &operator=(int output) { - f(idx, output); - return *this; - } -}; - -/** -* \brief Thrust compatible iterator type - discards algorithm output and launches device lambda -* with the index of the output and the algorithm output as arguments. -* -* \author Rory -* \date 7/9/2017 -* -* \tparam FunctionT Type of the function t. -*/ -template -class DiscardLambdaItr { -public: - // Required iterator traits - using self_type = DiscardLambdaItr; // NOLINT - using difference_type = ptrdiff_t; // NOLINT - using value_type = void; // NOLINT - using pointer = value_type *; // NOLINT - using reference = LauncherItr; // NOLINT - using iterator_category = typename thrust::detail::iterator_facade_category< // NOLINT - thrust::any_system_tag, thrust::random_access_traversal_tag, value_type, - reference>::type; // NOLINT -private: - difference_type offset_; - FunctionT f_; -public: - XGBOOST_DEVICE explicit DiscardLambdaItr(FunctionT f) : offset_(0), f_(f) {} - XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, FunctionT f) - : offset_(offset), f_(f) {} - XGBOOST_DEVICE self_type operator+(const int &b) const { - return DiscardLambdaItr(offset_ + b, f_); - } - XGBOOST_DEVICE self_type operator++() { - offset_++; - return *this; - } - XGBOOST_DEVICE self_type operator++(int) { - self_type retval = *this; - offset_++; - return retval; - } - XGBOOST_DEVICE self_type &operator+=(const int &b) { - offset_ += b; - return *this; - } - XGBOOST_DEVICE reference operator*() const { - return LauncherItr(offset_, f_); - } - XGBOOST_DEVICE reference operator[](int idx) { - self_type offset = (*this) + idx; - return *offset; - } -}; - } // namespace dh diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 80b777aa8..ca25cba25 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -44,7 +44,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { size_t Size() const { return num_elements_; } __device__ COOTuple GetElement(size_t idx) const { size_t column_idx = - dh::UpperBound(column_ptr_.data(), column_ptr_.size(), idx) - 1; + thrust::upper_bound(thrust::seq,column_ptr_.begin(), column_ptr_.end(), idx) - column_ptr_.begin() - 1; auto& column = columns_[column_idx]; size_t row_idx = idx - column_ptr_[column_idx]; float value = column.valid.Data() == nullptr || column.valid.Check(row_idx) diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 30a46873a..91f25b7a7 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -49,7 +49,9 @@ __global__ void CompressBinEllpackKernel( int ncuts = cut_rows[feature + 1] - cut_rows[feature]; // Assigning the bin in current entry. // S.t.: fvalue < feature_cuts[bin] - bin = dh::UpperBound(feature_cuts, ncuts, fvalue); + bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts, + fvalue) - + feature_cuts; if (bin >= ncuts) { bin = ncuts - 1; } diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu index 5d92d3b54..30cfdb08a 100644 --- a/src/objective/rank_obj.cu +++ b/src/objective/rank_obj.cu @@ -52,14 +52,18 @@ struct LambdaRankParam : public XGBoostParameter { template XGBOOST_DEVICE __forceinline__ uint32_t -CountNumItemsToTheLeftOf(const T * __restrict__ items, uint32_t n, T v) { - return dh::LowerBound(items, n, v, thrust::greater()); +CountNumItemsToTheLeftOf(const T *__restrict__ items, uint32_t n, T v) { + return thrust::lower_bound(thrust::seq, items, items + n, v, + thrust::greater()) - + items; } template XGBOOST_DEVICE __forceinline__ uint32_t -CountNumItemsToTheRightOf(const T * __restrict__ items, uint32_t n, T v) { - return n - dh::UpperBound(items, n, v, thrust::greater()); +CountNumItemsToTheRightOf(const T *__restrict__ items, uint32_t n, T v) { + return n - (thrust::upper_bound(thrust::seq, items, items + n, v, + thrust::greater()) - + items); } #endif @@ -671,7 +675,10 @@ class SortedLabelList : dh::SegmentSorter { dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) { // First, determine the group 'idx' belongs to uint32_t item_idx = idx % total_items; - uint32_t group_idx = dh::UpperBound(dgroups.data(), ngroups, item_idx); + uint32_t group_idx = + thrust::upper_bound(thrust::seq, dgroups.begin(), + dgroups.begin() + ngroups, item_idx) - + dgroups.begin(); // Span of this group within the larger labels/predictions sorted tuple uint32_t group_begin = dgroups[group_idx - 1]; uint32_t group_end = dgroups[group_idx]; diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu index 7427362e9..e8f55fee2 100644 --- a/src/tree/gpu_hist/row_partitioner.cu +++ b/src/tree/gpu_hist/row_partitioner.cu @@ -1,6 +1,8 @@ /*! * Copyright 2017-2019 XGBoost contributors */ +#include +#include #include #include #include "../../common/device_helpers.cuh" @@ -11,58 +13,74 @@ namespace tree { struct IndicateLeftTransform { bst_node_t left_nidx; - explicit IndicateLeftTransform(bst_node_t left_nidx) - : left_nidx(left_nidx) {} - __host__ __device__ __forceinline__ int operator()(const bst_node_t& x) const { + explicit IndicateLeftTransform(bst_node_t left_nidx) : left_nidx(left_nidx) {} + __host__ __device__ __forceinline__ size_t + operator()(const bst_node_t& x) const { return x == left_nidx ? 1 : 0; } }; -/* - * position: Position of rows belonged to current split node. - */ + +struct IndexFlagTuple { + size_t idx; + size_t flag; +}; + +struct IndexFlagOp { + __device__ IndexFlagTuple operator()(const IndexFlagTuple& a, + const IndexFlagTuple& b) const { + return {b.idx, a.flag + b.flag}; + } +}; + +struct WriteResultsFunctor { + bst_node_t left_nidx; + common::Span position_in; + common::Span position_out; + common::Span ridx_in; + common::Span ridx_out; + int64_t* d_left_count; + + __device__ int operator()(const IndexFlagTuple& x) { + // the ex_scan_result represents how many rows have been assigned to left + // node so far during scan. + int scatter_address; + if (position_in[x.idx] == left_nidx) { + scatter_address = x.flag - 1; // -1 because inclusive scan + } else { + // current number of rows belong to right node + total number of rows + // belong to left node + scatter_address = (x.idx - x.flag) + *d_left_count; + } + // copy the node id to output + position_out[scatter_address] = position_in[x.idx]; + ridx_out[scatter_address] = ridx_in[x.idx]; + + // Discard + return 0; + } +}; + void RowPartitioner::SortPosition(common::Span position, common::Span position_out, common::Span ridx, common::Span ridx_out, - bst_node_t left_nidx, - bst_node_t right_nidx, + bst_node_t left_nidx, bst_node_t right_nidx, int64_t* d_left_count, cudaStream_t stream) { - // radix sort over 1 bit, see: - // https://developer.nvidia.com/gpugems/GPUGems3/gpugems3_ch39.html - auto d_position_out = position_out.data(); - auto d_position_in = position.data(); - auto d_ridx_out = ridx_out.data(); - auto d_ridx_in = ridx.data(); - auto write_results = [=] __device__(size_t idx, int ex_scan_result) { - // the ex_scan_result represents how many rows have been assigned to left node so far - // during scan. - int scatter_address; - if (d_position_in[idx] == left_nidx) { - scatter_address = ex_scan_result; - } else { - // current number of rows belong to right node + total number of rows belong to left - // node - scatter_address = (idx - ex_scan_result) + *d_left_count; - } - // copy the node id to output - d_position_out[scatter_address] = d_position_in[idx]; - d_ridx_out[scatter_address] = d_ridx_in[idx]; - }; // NOLINT - - IndicateLeftTransform is_left(left_nidx); - // an iterator that given a old position returns whether it belongs to left or right - // node. - cub::TransformInputIterator - in_itr(d_position_in, is_left); - dh::DiscardLambdaItr out_itr(write_results); - size_t temp_storage_bytes = 0; - // position is of the same size with current split node's row segment - cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr, - position.size(), stream); - dh::caching_device_vector temp_storage(temp_storage_bytes); - cub::DeviceScan::ExclusiveSum(temp_storage.data().get(), temp_storage_bytes, - in_itr, out_itr, position.size(), stream); + WriteResultsFunctor write_results{left_nidx, position, position_out, + ridx, ridx_out, d_left_count}; + auto discard_write_iterator = thrust::make_transform_output_iterator( + thrust::discard_iterator(), write_results); + auto input_iterator = dh::MakeTransformIterator( + thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) { + return IndexFlagTuple{idx, position[idx] == left_nidx}; + }); + dh::XGBCachingDeviceAllocator alloc; + thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator, + input_iterator + position.size(), + discard_write_iterator, + [=] __device__(IndexFlagTuple a, IndexFlagTuple b) { + return IndexFlagTuple{b.idx, a.flag + b.flag}; + }); } RowPartitioner::RowPartitioner(int device_idx, size_t num_rows) @@ -137,7 +155,7 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment, SortPosition( // position_in common::Span(position_.Current() + segment.begin, - segment.Size()), + segment.Size()), // position_out common::Span(position_.Other() + segment.begin, segment.Size()), diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index 119cc6fe4..b10ea0c53 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -8,25 +8,6 @@ #include "../helpers.h" #include "gtest/gtest.h" -using xgboost::common::Span; - -void CreateTestData(xgboost::bst_uint num_rows, int max_row_size, - thrust::host_vector *row_ptr, - thrust::host_vector *rows) { - row_ptr->resize(num_rows + 1); - int sum = 0; - for (xgboost::bst_uint i = 0; i <= num_rows; i++) { - (*row_ptr)[i] = sum; - sum += rand() % max_row_size; // NOLINT - - if (i < num_rows) { - for (int j = (*row_ptr)[i]; j < sum; j++) { - (*rows).push_back(i); - } - } - } -} - TEST(SumReduce, Test) { thrust::device_vector data(100, 1.0f); dh::CubMemory temp; @@ -34,80 +15,3 @@ TEST(SumReduce, Test) { ASSERT_NEAR(sum, 100.0f, 1e-5); } -template > -void TestUpperBoundImpl(const std::vector &vec, T val_to_find, - const Comp &comp = Comp()) { - EXPECT_EQ(dh::UpperBound(vec.data(), vec.size(), val_to_find, comp), - std::upper_bound(vec.begin(), vec.end(), val_to_find, comp) - vec.begin()); -} - -template > -void TestLowerBoundImpl(const std::vector &vec, T val_to_find, - const Comp &comp = Comp()) { - EXPECT_EQ(dh::LowerBound(vec.data(), vec.size(), val_to_find, comp), - std::lower_bound(vec.begin(), vec.end(), val_to_find, comp) - vec.begin()); -} - -TEST(UpperBound, DataAscending) { - std::vector hvec{0, 3, 5, 5, 7, 8, 9, 10, 10}; - - // Test boundary conditions - TestUpperBoundImpl(hvec, hvec.front()); // Result 1 - TestUpperBoundImpl(hvec, hvec.front() - 1); // Result 0 - TestUpperBoundImpl(hvec, hvec.back() + 1); // Result hvec.size() - TestUpperBoundImpl(hvec, hvec.back()); // Result hvec.size() - - // Test other values - both missing and present - TestUpperBoundImpl(hvec, 3); // Result 2 - TestUpperBoundImpl(hvec, 4); // Result 2 - TestUpperBoundImpl(hvec, 5); // Result 4 -} - -TEST(UpperBound, DataDescending) { - std::vector hvec{10, 10, 9, 8, 7, 5, 5, 3, 0, 0}; - const auto &comparator = thrust::greater(); - - // Test boundary conditions - TestUpperBoundImpl(hvec, hvec.front(), comparator); // Result 2 - TestUpperBoundImpl(hvec, hvec.front() + 1, comparator); // Result 0 - TestUpperBoundImpl(hvec, hvec.back(), comparator); // Result hvec.size() - TestUpperBoundImpl(hvec, hvec.back() - 1, comparator); // Result hvec.size() - - // Test other values - both missing and present - TestUpperBoundImpl(hvec, 9, comparator); // Result 3 - TestUpperBoundImpl(hvec, 7, comparator); // Result 5 - TestUpperBoundImpl(hvec, 4, comparator); // Result 7 - TestUpperBoundImpl(hvec, 8, comparator); // Result 4 -} - -TEST(LowerBound, DataAscending) { - std::vector hvec{0, 3, 5, 5, 7, 8, 9, 10, 10}; - - // Test boundary conditions - TestLowerBoundImpl(hvec, hvec.front()); // Result 0 - TestLowerBoundImpl(hvec, hvec.front() - 1); // Result 0 - TestLowerBoundImpl(hvec, hvec.back()); // Result 7 - TestLowerBoundImpl(hvec, hvec.back() + 1); // Result hvec.size() - - // Test other values - both missing and present - TestLowerBoundImpl(hvec, 3); // Result 1 - TestLowerBoundImpl(hvec, 4); // Result 2 - TestLowerBoundImpl(hvec, 5); // Result 2 -} - -TEST(LowerBound, DataDescending) { - std::vector hvec{10, 10, 9, 8, 7, 5, 5, 3, 0, 0}; - const auto &comparator = thrust::greater(); - - // Test boundary conditions - TestLowerBoundImpl(hvec, hvec.front(), comparator); // Result 0 - TestLowerBoundImpl(hvec, hvec.front() + 1, comparator); // Result 0 - TestLowerBoundImpl(hvec, hvec.back(), comparator); // Result 8 - TestLowerBoundImpl(hvec, hvec.back() - 1, comparator); // Result hvec.size() - - // Test other values - both missing and present - TestLowerBoundImpl(hvec, 9, comparator); // Result 2 - TestLowerBoundImpl(hvec, 7, comparator); // Result 4 - TestLowerBoundImpl(hvec, 4, comparator); // Result 7 - TestLowerBoundImpl(hvec, 8, comparator); // Result 3 -}