Use thrust functions instead of custom functions (#5544)

Rory Mitchell 2020-04-16 21:41:16 +12:00 committed by GitHub
parent 6a169cd41a
commit e268fb0093
6 changed files with 82 additions and 306 deletions

View File

@@ -9,15 +9,14 @@
 #include <thrust/system_error.h>
 #include <thrust/logical.h>
 #include <thrust/gather.h>
+#include <thrust/binary_search.h>
-#include <omp.h>
 #include <rabit/rabit.h>
 #include <cub/cub.cuh>
 #include <cub/util_allocator.cuh>
 #include <algorithm>
 #include <chrono>
-#include <ctime>
 #include <numeric>
 #include <sstream>
 #include <string>
@@ -28,7 +27,6 @@
 #include "xgboost/span.h"
 #include "common.h"
-#include "timer.h"

 #ifdef XGBOOST_USE_NCCL
 #include "nccl.h"
@@ -132,94 +130,6 @@ DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, un
                 static_cast<unsigned int>(b) << (ibyte % (sizeof(unsigned int)) * 8));
 }
-
-namespace internal {
-
-// Items of size 'n' are sorted in an order determined by the Comparator
-// If left is true, find the number of elements where 'comp(item, v)' returns true;
-// 0 if nothing is true
-// If left is false, find the number of elements where '!comp(item, v)' returns true;
-// 0 if nothing is true
-template <typename T, typename Comparator = thrust::greater<T>>
-XGBOOST_DEVICE __forceinline__ uint32_t
-CountNumItemsImpl(bool left, const T * __restrict__ items, uint32_t n, T v,
-                  const Comparator &comp = Comparator()) {
-  const T *items_begin = items;
-  uint32_t num_remaining = n;
-  const T *middle_item = nullptr;
-  uint32_t middle;
-  while (num_remaining > 0) {
-    middle_item = items_begin;
-    middle = num_remaining / 2;
-    middle_item += middle;
-    if ((left && comp(*middle_item, v)) || (!left && !comp(v, *middle_item))) {
-      items_begin = ++middle_item;
-      num_remaining -= middle + 1;
-    } else {
-      num_remaining = middle;
-    }
-  }
-
-  return left ? items_begin - items : items + n - items_begin;
-}
-
-}  // namespace internal
-
-/*!
- * \brief Find the strict upper bound for an element in a sorted array
- *        using binary search.
- * \param items pointer to the first element of the sorted array
- * \param n length of the sorted array
- * \param v value for which to find the upper bound
- * \param comp determines how the items are sorted ascending/descending order - should conform
- *        to ordering semantics
- * \return the smallest index i that has a value > v, or n if none is larger when sorted ascendingly
- *         or, an index i with a value < v, or 0 if none is smaller when sorted descendingly
- */
-// Preserve existing default behavior of upper bound
-template <typename T, typename Comp = thrust::less<T>>
-XGBOOST_DEVICE __forceinline__ uint32_t UpperBound(const T *__restrict__ items,
-                                                   uint32_t n,
-                                                   T v,
-                                                   const Comp &comp = Comp()) {
-  if (std::is_same<Comp, thrust::less<T>>::value ||
-      std::is_same<Comp, thrust::greater<T>>::value) {
-    return n - internal::CountNumItemsImpl(false, items, n, v, comp);
-  } else {
-    static_assert(std::is_same<Comp, thrust::less<T>>::value ||
-                  std::is_same<Comp, thrust::greater<T>>::value,
-                  "Invalid comparator used in Upperbound - can only be thrust::greater/less");
-    return std::numeric_limits<uint32_t>::max();  // Simply to quiesce the compiler
-  }
-}
-
-/*!
- * \brief Find the strict lower bound for an element in a sorted array
- *        using binary search.
- * \param items pointer to the first element of the sorted array
- * \param n length of the sorted array
- * \param v value for which to find the upper bound
- * \param comp determines how the items are sorted ascending/descending order - should conform
- *        to ordering semantics
- * \return the smallest index i that has a value >= v, or n if none is larger
- *         when sorted ascendingly
- *         or, an index i with a value <= v, or 0 if none is smaller when sorted descendingly
- */
-template <typename T, typename Comp = thrust::less<T>>
-XGBOOST_DEVICE __forceinline__ uint32_t LowerBound(const T *__restrict__ items,
-                                                   uint32_t n,
-                                                   T v,
-                                                   const Comp &comp = Comp()) {
-  if (std::is_same<Comp, thrust::less<T>>::value ||
-      std::is_same<Comp, thrust::greater<T>>::value) {
-    return internal::CountNumItemsImpl(true, items, n, v, comp);
-  } else {
-    static_assert(std::is_same<Comp, thrust::less<T>>::value ||
-                  std::is_same<Comp, thrust::greater<T>>::value,
-                  "Invalid comparator used in LowerBound - can only be thrust::greater/less");
-    return std::numeric_limits<uint32_t>::max();  // Simply to quiesce the compiler
-  }
-}
-
 template <typename T>
 __device__ xgboost::common::Range GridStrideRange(T begin, T end) {
   begin += blockDim.x * blockIdx.x + threadIdx.x;
@@ -878,7 +788,8 @@ class SegmentSorter {
     const uint32_t *dgroups = dgroups_.data().get();
     uint32_t ngroups = dgroups_.size();
     auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) {
-      return dh::UpperBound(dgroups, ngroups, idx) - 1;
+      return thrust::upper_bound(thrust::seq, dgroups, dgroups + ngroups, idx) -
+             dgroups - 1;
     };  // NOLINT

     thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
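
This is the pattern the whole commit applies: dh::UpperBound(ptr, n, v) becomes thrust::upper_bound(thrust::seq, ptr, ptr + n, v) - ptr, and likewise for LowerBound. The thrust::seq execution policy makes the search run sequentially inside the calling device thread, so it is usable from kernels and device lambdas just like the removed helpers. A minimal sketch of the idiom, not part of this commit (SegmentIdKernel and the other names are illustrative), mapping element indices to their segment from a sorted offsets array:

    // Minimal sketch, not part of this commit: the thrust::seq idiom shown as a
    // standalone .cu program. SegmentIdKernel, d_offsets and d_out are
    // illustrative names, not from the XGBoost sources.
    #include <thrust/binary_search.h>
    #include <thrust/execution_policy.h>
    #include <cuda_runtime.h>
    #include <cstdint>
    #include <cstdio>

    // For each element index i, find the segment it falls into, given sorted
    // segment start offsets (CSR-style). Equivalent to the removed
    // dh::UpperBound(offsets, n_offsets, i) - 1.
    __global__ void SegmentIdKernel(const uint32_t* d_offsets, uint32_t n_offsets,
                                    uint32_t n_items, uint32_t* d_out) {
      uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i >= n_items) { return; }
      // thrust::seq runs the binary search sequentially inside this thread.
      d_out[i] = thrust::upper_bound(thrust::seq, d_offsets, d_offsets + n_offsets, i) -
                 d_offsets - 1;
    }

    int main() {
      // Three segments: [0,3), [3,4), [4,7)
      uint32_t h_offsets[] = {0, 3, 4, 7};
      uint32_t *d_offsets, *d_out;
      cudaMalloc(&d_offsets, sizeof(h_offsets));
      cudaMalloc(&d_out, 7 * sizeof(uint32_t));
      cudaMemcpy(d_offsets, h_offsets, sizeof(h_offsets), cudaMemcpyHostToDevice);
      SegmentIdKernel<<<1, 32>>>(d_offsets, 4, 7, d_out);
      uint32_t h_out[7];
      cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
      for (uint32_t v : h_out) { printf("%u ", v); }  // 0 0 0 1 2 2 2
      printf("\n");
      cudaFree(d_offsets);
      cudaFree(d_out);
      return 0;
    }
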
@@ -1018,70 +929,4 @@ thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
   return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
 }
-
-template <typename FunctionT>
-class LauncherItr {
- public:
-  int idx { 0 };
-  FunctionT f;
-  XGBOOST_DEVICE LauncherItr() : idx(0) {}  // NOLINT
-  XGBOOST_DEVICE LauncherItr(int idx, FunctionT f) : idx(idx), f(f) {}
-  XGBOOST_DEVICE LauncherItr &operator=(int output) {
-    f(idx, output);
-    return *this;
-  }
-};
-
-/**
- * \brief Thrust compatible iterator type - discards algorithm output and launches device lambda
- *        with the index of the output and the algorithm output as arguments.
- *
- * \author Rory
- * \date 7/9/2017
- *
- * \tparam FunctionT Type of the function t.
- */
-template <typename FunctionT>
-class DiscardLambdaItr {
- public:
-  // Required iterator traits
-  using self_type = DiscardLambdaItr;  // NOLINT
-  using difference_type = ptrdiff_t;   // NOLINT
-  using value_type = void;             // NOLINT
-  using pointer = value_type *;        // NOLINT
-  using reference = LauncherItr<FunctionT>;  // NOLINT
-  using iterator_category = typename thrust::detail::iterator_facade_category<  // NOLINT
-      thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
-      reference>::type;  // NOLINT
-
- private:
-  difference_type offset_;
-  FunctionT f_;
-
- public:
-  XGBOOST_DEVICE explicit DiscardLambdaItr(FunctionT f) : offset_(0), f_(f) {}
-  XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, FunctionT f)
-      : offset_(offset), f_(f) {}
-  XGBOOST_DEVICE self_type operator+(const int &b) const {
-    return DiscardLambdaItr(offset_ + b, f_);
-  }
-  XGBOOST_DEVICE self_type operator++() {
-    offset_++;
-    return *this;
-  }
-  XGBOOST_DEVICE self_type operator++(int) {
-    self_type retval = *this;
-    offset_++;
-    return retval;
-  }
-  XGBOOST_DEVICE self_type &operator+=(const int &b) {
-    offset_ += b;
-    return *this;
-  }
-  XGBOOST_DEVICE reference operator*() const {
-    return LauncherItr<FunctionT>(offset_, f_);
-  }
-  XGBOOST_DEVICE reference operator[](int idx) {
-    self_type offset = (*this) + idx;
-    return *offset;
-  }
-};
-
 }  // namespace dh
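
The block removed above was only needed to feed per-element writes from a CUB/Thrust algorithm into a device lambda while discarding the algorithm's own output. The replacement idiom, used later in this commit in the row partitioner, wraps a thrust::discard_iterator in thrust::make_transform_output_iterator: a functor receives every value the algorithm writes, performs the side effect, and returns a dummy that is thrown away. A minimal sketch of that idiom, not taken from the commit (SquareWriter and the surrounding names are illustrative):

    // Minimal sketch, not part of this commit: the thrust idiom that replaces
    // DiscardLambdaItr/LauncherItr. SquareWriter and the other names are
    // illustrative.
    #include <thrust/copy.h>
    #include <thrust/device_vector.h>
    #include <thrust/execution_policy.h>
    #include <thrust/iterator/counting_iterator.h>
    #include <thrust/iterator/discard_iterator.h>
    #include <thrust/iterator/transform_output_iterator.h>
    #include <cstdio>

    struct SquareWriter {
      int* out;  // raw device pointer, captured by value
      __device__ int operator()(int i) const {
        out[i] = i * i;  // the useful side effect
        return 0;        // dummy value, swallowed by the discard_iterator
      }
    };

    int main() {
      const int n = 8;
      thrust::device_vector<int> squares(n, 0);
      SquareWriter writer{thrust::raw_pointer_cast(squares.data())};
      // Every value the algorithm writes is handed to SquareWriter; the functor's
      // own result goes into a discard_iterator and vanishes.
      auto sink = thrust::make_transform_output_iterator(
          thrust::make_discard_iterator(), writer);
      thrust::copy(thrust::device, thrust::counting_iterator<int>(0),
                   thrust::counting_iterator<int>(n), sink);
      for (int i = 0; i < n; ++i) { printf("%d ", (int)squares[i]); }  // 0 1 4 9 ...
      printf("\n");
      return 0;
    }
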

View File

@@ -44,7 +44,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
   size_t Size() const { return num_elements_; }
   __device__ COOTuple GetElement(size_t idx) const {
     size_t column_idx =
-        dh::UpperBound(column_ptr_.data(), column_ptr_.size(), idx) - 1;
+        thrust::upper_bound(thrust::seq, column_ptr_.begin(), column_ptr_.end(), idx) - column_ptr_.begin() - 1;
     auto& column = columns_[column_idx];
     size_t row_idx = idx - column_ptr_[column_idx];
     float value = column.valid.Data() == nullptr || column.valid.Check(row_idx)

View File

@@ -49,7 +49,9 @@ __global__ void CompressBinEllpackKernel(
     int ncuts = cut_rows[feature + 1] - cut_rows[feature];
     // Assigning the bin in current entry.
     // S.t.: fvalue < feature_cuts[bin]
-    bin = dh::UpperBound(feature_cuts, ncuts, fvalue);
+    bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts,
+                              fvalue) -
+          feature_cuts;
     if (bin >= ncuts) {
       bin = ncuts - 1;
     }
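
For reference, the binning rule the kernel applies is easy to check on the host, since thrust::upper_bound with thrust::seq is also callable from host code: the bin is the index of the first cut strictly greater than the feature value, clamped so that values at or beyond the last cut fall into the last bin. A small sketch, not part of the commit (AssignBin and the cut values are made up):

    // Minimal sketch, not part of this commit: the binning rule, checked on the
    // host. AssignBin and the cut values are illustrative.
    #include <thrust/binary_search.h>
    #include <thrust/execution_policy.h>
    #include <cassert>

    inline int AssignBin(const float* feature_cuts, int ncuts, float fvalue) {
      int bin = static_cast<int>(
          thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts, fvalue) -
          feature_cuts);
      if (bin >= ncuts) {
        bin = ncuts - 1;  // values at or beyond the last cut land in the last bin
      }
      return bin;
    }

    int main() {
      const float cuts[] = {0.5f, 1.5f, 2.5f};  // cut points for one feature
      assert(AssignBin(cuts, 3, 0.1f) == 0);    // 0.1 < 0.5
      assert(AssignBin(cuts, 3, 1.0f) == 1);    // 0.5 <= 1.0 < 1.5
      assert(AssignBin(cuts, 3, 2.0f) == 2);    // 1.5 <= 2.0 < 2.5
      assert(AssignBin(cuts, 3, 9.0f) == 2);    // clamped to the last bin
      return 0;
    }
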

View File

@@ -53,13 +53,17 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
 template <typename T>
 XGBOOST_DEVICE __forceinline__ uint32_t
 CountNumItemsToTheLeftOf(const T *__restrict__ items, uint32_t n, T v) {
-  return dh::LowerBound(items, n, v, thrust::greater<T>());
+  return thrust::lower_bound(thrust::seq, items, items + n, v,
+                             thrust::greater<T>()) -
+         items;
 }

 template <typename T>
 XGBOOST_DEVICE __forceinline__ uint32_t
 CountNumItemsToTheRightOf(const T *__restrict__ items, uint32_t n, T v) {
-  return n - dh::UpperBound(items, n, v, thrust::greater<T>());
+  return n - (thrust::upper_bound(thrust::seq, items, items + n, v,
+                                  thrust::greater<T>()) -
+              items);
 }
 #endif
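
The ranking code keeps items sorted in descending order of predicted score, hence the thrust::greater<T> comparator: lower_bound then counts the items strictly greater than v (those to its left), and n minus upper_bound counts the items strictly less than v (those to its right). A host-side sketch of the same arithmetic, not part of the commit (ItemsToTheLeftOf and ItemsToTheRightOf are illustrative stand-ins for the device helpers):

    // Minimal sketch, not part of this commit: the descending-order counting done
    // by the two helpers, run on the host via thrust::seq. Names are illustrative.
    #include <thrust/binary_search.h>
    #include <thrust/execution_policy.h>
    #include <thrust/functional.h>
    #include <cassert>
    #include <cstdint>

    // Number of items strictly greater than v (to its left in descending order).
    uint32_t ItemsToTheLeftOf(const float* items, uint32_t n, float v) {
      return thrust::lower_bound(thrust::seq, items, items + n, v,
                                 thrust::greater<float>()) - items;
    }

    // Number of items strictly less than v (to its right in descending order).
    uint32_t ItemsToTheRightOf(const float* items, uint32_t n, float v) {
      return n - (thrust::upper_bound(thrust::seq, items, items + n, v,
                                      thrust::greater<float>()) - items);
    }

    int main() {
      const float scores[] = {9.f, 7.f, 7.f, 4.f, 1.f};  // sorted descending
      assert(ItemsToTheLeftOf(scores, 5, 7.f) == 1);      // only 9 is greater
      assert(ItemsToTheRightOf(scores, 5, 7.f) == 2);     // 4 and 1 are smaller
      assert(ItemsToTheLeftOf(scores, 5, 10.f) == 0);
      assert(ItemsToTheRightOf(scores, 5, 0.f) == 0);
      return 0;
    }
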
@@ -671,7 +675,10 @@ class SortedLabelList : dh::SegmentSorter<float> {
     dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) {
       // First, determine the group 'idx' belongs to
       uint32_t item_idx = idx % total_items;
-      uint32_t group_idx = dh::UpperBound(dgroups.data(), ngroups, item_idx);
+      uint32_t group_idx =
+          thrust::upper_bound(thrust::seq, dgroups.begin(),
+                              dgroups.begin() + ngroups, item_idx) -
+          dgroups.begin();
       // Span of this group within the larger labels/predictions sorted tuple
       uint32_t group_begin = dgroups[group_idx - 1];
       uint32_t group_end = dgroups[group_idx];

View File

@@ -1,6 +1,8 @@
 /*!
  * Copyright 2017-2019 XGBoost contributors
  */
+#include <thrust/iterator/discard_iterator.h>
+#include <thrust/iterator/transform_output_iterator.h>
 #include <thrust/sequence.h>
 #include <vector>
 #include "../../common/device_helpers.cuh"
@@ -11,58 +13,74 @@ namespace tree {
 struct IndicateLeftTransform {
   bst_node_t left_nidx;
-  explicit IndicateLeftTransform(bst_node_t left_nidx)
-      : left_nidx(left_nidx) {}
-  __host__ __device__ __forceinline__ int operator()(const bst_node_t& x) const {
+  explicit IndicateLeftTransform(bst_node_t left_nidx) : left_nidx(left_nidx) {}
+  __host__ __device__ __forceinline__ size_t
+  operator()(const bst_node_t& x) const {
     return x == left_nidx ? 1 : 0;
   }
 };
-/*
- * position: Position of rows belonged to current split node.
- */
+
+struct IndexFlagTuple {
+  size_t idx;
+  size_t flag;
+};
+
+struct IndexFlagOp {
+  __device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
+                                       const IndexFlagTuple& b) const {
+    return {b.idx, a.flag + b.flag};
+  }
+};
+
+struct WriteResultsFunctor {
+  bst_node_t left_nidx;
+  common::Span<bst_node_t> position_in;
+  common::Span<bst_node_t> position_out;
+  common::Span<RowPartitioner::RowIndexT> ridx_in;
+  common::Span<RowPartitioner::RowIndexT> ridx_out;
+  int64_t* d_left_count;
+
+  __device__ int operator()(const IndexFlagTuple& x) {
+    // the ex_scan_result represents how many rows have been assigned to left
+    // node so far during scan.
+    int scatter_address;
+    if (position_in[x.idx] == left_nidx) {
+      scatter_address = x.flag - 1;  // -1 because inclusive scan
+    } else {
+      // current number of rows belong to right node + total number of rows
+      // belong to left node
+      scatter_address = (x.idx - x.flag) + *d_left_count;
+    }
+    // copy the node id to output
+    position_out[scatter_address] = position_in[x.idx];
+    ridx_out[scatter_address] = ridx_in[x.idx];
+    // Discard
+    return 0;
+  }
+};
+
 void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
                                   common::Span<bst_node_t> position_out,
                                   common::Span<RowIndexT> ridx,
                                   common::Span<RowIndexT> ridx_out,
-                                  bst_node_t left_nidx,
-                                  bst_node_t right_nidx,
+                                  bst_node_t left_nidx, bst_node_t right_nidx,
                                   int64_t* d_left_count, cudaStream_t stream) {
-  // radix sort over 1 bit, see:
-  // https://developer.nvidia.com/gpugems/GPUGems3/gpugems3_ch39.html
-  auto d_position_out = position_out.data();
-  auto d_position_in = position.data();
-  auto d_ridx_out = ridx_out.data();
-  auto d_ridx_in = ridx.data();
-  auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
-    // the ex_scan_result represents how many rows have been assigned to left node so far
-    // during scan.
-    int scatter_address;
-    if (d_position_in[idx] == left_nidx) {
-      scatter_address = ex_scan_result;
-    } else {
-      // current number of rows belong to right node + total number of rows belong to left
-      // node
-      scatter_address = (idx - ex_scan_result) + *d_left_count;
-    }
-    // copy the node id to output
-    d_position_out[scatter_address] = d_position_in[idx];
-    d_ridx_out[scatter_address] = d_ridx_in[idx];
-  };  // NOLINT
-
-  IndicateLeftTransform is_left(left_nidx);
-  // an iterator that given a old position returns whether it belongs to left or right
-  // node.
-  cub::TransformInputIterator<bst_node_t, IndicateLeftTransform,
-                              bst_node_t*>
-      in_itr(d_position_in, is_left);
-  dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
-  size_t temp_storage_bytes = 0;
-  // position is of the same size with current split node's row segment
-  cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
-                                position.size(), stream);
-  dh::caching_device_vector<uint8_t> temp_storage(temp_storage_bytes);
-  cub::DeviceScan::ExclusiveSum(temp_storage.data().get(), temp_storage_bytes,
-                                in_itr, out_itr, position.size(), stream);
+  WriteResultsFunctor write_results{left_nidx, position, position_out,
+                                    ridx,      ridx_out, d_left_count};
+  auto discard_write_iterator = thrust::make_transform_output_iterator(
+      thrust::discard_iterator<int>(), write_results);
+  auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
+      thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
+        return IndexFlagTuple{idx, position[idx] == left_nidx};
+      });
+  dh::XGBCachingDeviceAllocator<char> alloc;
+  thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator,
+                         input_iterator + position.size(),
+                         discard_write_iterator,
+                         [=] __device__(IndexFlagTuple a, IndexFlagTuple b) {
+                           return IndexFlagTuple{b.idx, a.flag + b.flag};
+                         });
 }

 RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
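
The rewritten SortPosition is a scan-based stable partition. Each row index is mapped to an IndexFlagTuple holding the index and a flag that is 1 when the row goes to the left child; the inclusive scan combines tuples by keeping the newest index and summing the flags, so every scanned element knows its own index and the running count of left rows; the write functor then scatters the row either to flag - 1 (the left block) or to idx - flag + left_count (the right block), and the transform_output_iterator/discard_iterator pair turns the scan's output stream into those scatter writes. A self-contained toy version of the same technique, not the commit's code (the left count is hardcoded and names such as MakeTuple and Scatter are illustrative):

    // Minimal sketch, not the commit's code: the scan-based stable partition
    // performed by the new SortPosition, reduced to a toy program.
    #include <thrust/device_vector.h>
    #include <thrust/execution_policy.h>
    #include <thrust/iterator/counting_iterator.h>
    #include <thrust/iterator/discard_iterator.h>
    #include <thrust/iterator/transform_iterator.h>
    #include <thrust/iterator/transform_output_iterator.h>
    #include <thrust/scan.h>
    #include <cstdio>
    #include <vector>

    struct IndexFlag {
      size_t idx;   // index of the row this partial result belongs to
      size_t flag;  // inclusive count of "left" rows up to idx
    };

    // Scan input: (row index, 1 if the row goes to the left node, else 0).
    struct MakeTuple {
      int left_value;
      const int* rows;
      __host__ __device__ IndexFlag operator()(size_t i) const {
        return IndexFlag{i, rows[i] == left_value ? size_t{1} : size_t{0}};
      }
    };

    // Scan operator: carry the newest index forward, accumulate the flags.
    struct CombineOp {
      __host__ __device__ IndexFlag operator()(IndexFlag a, IndexFlag b) const {
        return IndexFlag{b.idx, a.flag + b.flag};
      }
    };

    // Output functor: scatter each row into the left or right block; the returned
    // int is swallowed by the discard_iterator.
    struct Scatter {
      int left_value;
      const int* rows;
      int* out;
      size_t left_total;  // total number of "left" rows, known up front here
      __device__ int operator()(IndexFlag x) const {
        size_t pos = (rows[x.idx] == left_value)
                         ? x.flag - 1                      // -1: inclusive scan
                         : (x.idx - x.flag) + left_total;  // after the left block
        out[pos] = rows[x.idx];
        return 0;
      }
    };

    int main() {
      std::vector<int> h_rows{1, 2, 2, 1, 2, 1};  // node id per row; node 1 goes left
      thrust::device_vector<int> rows(h_rows.begin(), h_rows.end());
      thrust::device_vector<int> part(rows.size());
      const int* d_rows = thrust::raw_pointer_cast(rows.data());

      auto in = thrust::make_transform_iterator(
          thrust::counting_iterator<size_t>(0), MakeTuple{1, d_rows});
      auto sink = thrust::make_transform_output_iterator(
          thrust::make_discard_iterator(),
          Scatter{1, d_rows, thrust::raw_pointer_cast(part.data()), 3});

      thrust::inclusive_scan(thrust::device, in, in + rows.size(), sink, CombineOp{});

      for (size_t i = 0; i < part.size(); ++i) { printf("%d ", (int)part[i]); }
      printf("\n");  // expected: 1 1 1 2 2 2
      return 0;
    }

The partition is stable because the scan result for row i depends only on rows 0..i, which is what lets this replace the earlier cub::DeviceScan plus custom discard iterator without changing the output order.
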

View File

@@ -8,25 +8,6 @@
 #include "../helpers.h"
 #include "gtest/gtest.h"

-using xgboost::common::Span;
-
-void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
-                    thrust::host_vector<int> *row_ptr,
-                    thrust::host_vector<xgboost::bst_uint> *rows) {
-  row_ptr->resize(num_rows + 1);
-  int sum = 0;
-  for (xgboost::bst_uint i = 0; i <= num_rows; i++) {
-    (*row_ptr)[i] = sum;
-    sum += rand() % max_row_size;  // NOLINT
-    if (i < num_rows) {
-      for (int j = (*row_ptr)[i]; j < sum; j++) {
-        (*rows).push_back(i);
-      }
-    }
-  }
-}
-
 TEST(SumReduce, Test) {
   thrust::device_vector<float> data(100, 1.0f);
   dh::CubMemory temp;
@@ -34,80 +15,3 @@ TEST(SumReduce, Test) {
   ASSERT_NEAR(sum, 100.0f, 1e-5);
 }
-
-template <typename T, typename Comp = thrust::less<T>>
-void TestUpperBoundImpl(const std::vector<T> &vec, T val_to_find,
-                        const Comp &comp = Comp()) {
-  EXPECT_EQ(dh::UpperBound(vec.data(), vec.size(), val_to_find, comp),
-            std::upper_bound(vec.begin(), vec.end(), val_to_find, comp) - vec.begin());
-}
-
-template <typename T, typename Comp = thrust::less<T>>
-void TestLowerBoundImpl(const std::vector<T> &vec, T val_to_find,
-                        const Comp &comp = Comp()) {
-  EXPECT_EQ(dh::LowerBound(vec.data(), vec.size(), val_to_find, comp),
-            std::lower_bound(vec.begin(), vec.end(), val_to_find, comp) - vec.begin());
-}
-
-TEST(UpperBound, DataAscending) {
-  std::vector<int> hvec{0, 3, 5, 5, 7, 8, 9, 10, 10};
-
-  // Test boundary conditions
-  TestUpperBoundImpl(hvec, hvec.front());      // Result 1
-  TestUpperBoundImpl(hvec, hvec.front() - 1);  // Result 0
-  TestUpperBoundImpl(hvec, hvec.back() + 1);   // Result hvec.size()
-  TestUpperBoundImpl(hvec, hvec.back());       // Result hvec.size()
-
-  // Test other values - both missing and present
-  TestUpperBoundImpl(hvec, 3);  // Result 2
-  TestUpperBoundImpl(hvec, 4);  // Result 2
-  TestUpperBoundImpl(hvec, 5);  // Result 4
-}
-
-TEST(UpperBound, DataDescending) {
-  std::vector<int> hvec{10, 10, 9, 8, 7, 5, 5, 3, 0, 0};
-  const auto &comparator = thrust::greater<int>();
-
-  // Test boundary conditions
-  TestUpperBoundImpl(hvec, hvec.front(), comparator);      // Result 2
-  TestUpperBoundImpl(hvec, hvec.front() + 1, comparator);  // Result 0
-  TestUpperBoundImpl(hvec, hvec.back(), comparator);       // Result hvec.size()
-  TestUpperBoundImpl(hvec, hvec.back() - 1, comparator);   // Result hvec.size()
-
-  // Test other values - both missing and present
-  TestUpperBoundImpl(hvec, 9, comparator);  // Result 3
-  TestUpperBoundImpl(hvec, 7, comparator);  // Result 5
-  TestUpperBoundImpl(hvec, 4, comparator);  // Result 7
-  TestUpperBoundImpl(hvec, 8, comparator);  // Result 4
-}
-
-TEST(LowerBound, DataAscending) {
-  std::vector<int> hvec{0, 3, 5, 5, 7, 8, 9, 10, 10};
-
-  // Test boundary conditions
-  TestLowerBoundImpl(hvec, hvec.front());      // Result 0
-  TestLowerBoundImpl(hvec, hvec.front() - 1);  // Result 0
-  TestLowerBoundImpl(hvec, hvec.back());       // Result 7
-  TestLowerBoundImpl(hvec, hvec.back() + 1);   // Result hvec.size()
-
-  // Test other values - both missing and present
-  TestLowerBoundImpl(hvec, 3);  // Result 1
-  TestLowerBoundImpl(hvec, 4);  // Result 2
-  TestLowerBoundImpl(hvec, 5);  // Result 2
-}
-
-TEST(LowerBound, DataDescending) {
-  std::vector<int> hvec{10, 10, 9, 8, 7, 5, 5, 3, 0, 0};
-  const auto &comparator = thrust::greater<int>();
-
-  // Test boundary conditions
-  TestLowerBoundImpl(hvec, hvec.front(), comparator);      // Result 0
-  TestLowerBoundImpl(hvec, hvec.front() + 1, comparator);  // Result 0
-  TestLowerBoundImpl(hvec, hvec.back(), comparator);       // Result 8
-  TestLowerBoundImpl(hvec, hvec.back() - 1, comparator);   // Result hvec.size()
-
-  // Test other values - both missing and present
-  TestLowerBoundImpl(hvec, 9, comparator);  // Result 2
-  TestLowerBoundImpl(hvec, 7, comparator);  // Result 4
-  TestLowerBoundImpl(hvec, 4, comparator);  // Result 7
-  TestLowerBoundImpl(hvec, 8, comparator);  // Result 3
-}
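
The deleted tests compared dh::UpperBound and dh::LowerBound against their std:: counterparts; the commit removes them without replacement because the helpers themselves are gone. Purely as a hypothetical illustration, not something this commit adds, an equivalent spot-check of the thrust calls against the standard library could look like this:

    // Hypothetical sketch, not part of this commit: checking the thrust searches
    // against the standard library, mirroring the removed tests.
    #include <thrust/binary_search.h>
    #include <thrust/execution_policy.h>
    #include <thrust/functional.h>
    #include <algorithm>
    #include <cassert>
    #include <vector>

    template <typename T, typename Comp = thrust::less<T>>
    void CheckBounds(const std::vector<T>& vec, T v, const Comp& comp = Comp()) {
      auto n = vec.size();
      // thrust::seq makes these the same sequential searches the device code runs.
      auto t_upper = thrust::upper_bound(thrust::seq, vec.data(), vec.data() + n,
                                         v, comp) - vec.data();
      auto t_lower = thrust::lower_bound(thrust::seq, vec.data(), vec.data() + n,
                                         v, comp) - vec.data();
      assert(t_upper == std::upper_bound(vec.begin(), vec.end(), v, comp) - vec.begin());
      assert(t_lower == std::lower_bound(vec.begin(), vec.end(), v, comp) - vec.begin());
    }

    int main() {
      std::vector<int> asc{0, 3, 5, 5, 7, 8, 9, 10, 10};
      std::vector<int> desc{10, 10, 9, 8, 7, 5, 5, 3, 0, 0};
      for (int v : {-1, 0, 3, 4, 5, 10, 11}) {
        CheckBounds(asc, v);
        CheckBounds(desc, v, thrust::greater<int>());
      }
      return 0;
    }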