Use thrust functions instead of custom functions (#5544)
This commit is contained in:
@@ -9,15 +9,14 @@
|
||||
#include <thrust/system_error.h>
|
||||
#include <thrust/logical.h>
|
||||
#include <thrust/gather.h>
|
||||
#include <thrust/binary_search.h>
|
||||
|
||||
#include <omp.h>
|
||||
#include <rabit/rabit.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_allocator.cuh>
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <ctime>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -28,7 +27,6 @@
|
||||
#include "xgboost/span.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "timer.h"
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl.h"
|
||||
@@ -132,94 +130,6 @@ DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, un
|
||||
static_cast<unsigned int>(b) << (ibyte % (sizeof(unsigned int)) * 8));
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Items of size 'n' are sorted in an order determined by the Comparator
|
||||
// If left is true, find the number of elements where 'comp(item, v)' returns true;
|
||||
// 0 if nothing is true
|
||||
// If left is false, find the number of elements where '!comp(item, v)' returns true;
|
||||
// 0 if nothing is true
|
||||
template <typename T, typename Comparator = thrust::greater<T>>
|
||||
XGBOOST_DEVICE __forceinline__ uint32_t
|
||||
CountNumItemsImpl(bool left, const T * __restrict__ items, uint32_t n, T v,
|
||||
const Comparator &comp = Comparator()) {
|
||||
const T *items_begin = items;
|
||||
uint32_t num_remaining = n;
|
||||
const T *middle_item = nullptr;
|
||||
uint32_t middle;
|
||||
while (num_remaining > 0) {
|
||||
middle_item = items_begin;
|
||||
middle = num_remaining / 2;
|
||||
middle_item += middle;
|
||||
if ((left && comp(*middle_item, v)) || (!left && !comp(v, *middle_item))) {
|
||||
items_begin = ++middle_item;
|
||||
num_remaining -= middle + 1;
|
||||
} else {
|
||||
num_remaining = middle;
|
||||
}
|
||||
}
|
||||
|
||||
return left ? items_begin - items : items + n - items_begin;
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
||||
/*!
|
||||
* \brief Find the strict upper bound for an element in a sorted array
|
||||
* using binary search.
|
||||
* \param items pointer to the first element of the sorted array
|
||||
* \param n length of the sorted array
|
||||
* \param v value for which to find the upper bound
|
||||
* \param comp determines how the items are sorted ascending/descending order - should conform
|
||||
* to ordering semantics
|
||||
* \return the smallest index i that has a value > v, or n if none is larger when sorted ascendingly
|
||||
* or, an index i with a value < v, or 0 if none is smaller when sorted descendingly
|
||||
*/
|
||||
// Preserve existing default behavior of upper bound
|
||||
template <typename T, typename Comp = thrust::less<T>>
|
||||
XGBOOST_DEVICE __forceinline__ uint32_t UpperBound(const T *__restrict__ items,
|
||||
uint32_t n,
|
||||
T v,
|
||||
const Comp &comp = Comp()) {
|
||||
if (std::is_same<Comp, thrust::less<T>>::value ||
|
||||
std::is_same<Comp, thrust::greater<T>>::value) {
|
||||
return n - internal::CountNumItemsImpl(false, items, n, v, comp);
|
||||
} else {
|
||||
static_assert(std::is_same<Comp, thrust::less<T>>::value ||
|
||||
std::is_same<Comp, thrust::greater<T>>::value,
|
||||
"Invalid comparator used in Upperbound - can only be thrust::greater/less");
|
||||
return std::numeric_limits<uint32_t>::max(); // Simply to quiesce the compiler
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Find the strict lower bound for an element in a sorted array
|
||||
* using binary search.
|
||||
* \param items pointer to the first element of the sorted array
|
||||
* \param n length of the sorted array
|
||||
* \param v value for which to find the upper bound
|
||||
* \param comp determines how the items are sorted ascending/descending order - should conform
|
||||
* to ordering semantics
|
||||
* \return the smallest index i that has a value >= v, or n if none is larger
|
||||
* when sorted ascendingly
|
||||
* or, an index i with a value <= v, or 0 if none is smaller when sorted descendingly
|
||||
*/
|
||||
template <typename T, typename Comp = thrust::less<T>>
|
||||
XGBOOST_DEVICE __forceinline__ uint32_t LowerBound(const T *__restrict__ items,
|
||||
uint32_t n,
|
||||
T v,
|
||||
const Comp &comp = Comp()) {
|
||||
if (std::is_same<Comp, thrust::less<T>>::value ||
|
||||
std::is_same<Comp, thrust::greater<T>>::value) {
|
||||
return internal::CountNumItemsImpl(true, items, n, v, comp);
|
||||
} else {
|
||||
static_assert(std::is_same<Comp, thrust::less<T>>::value ||
|
||||
std::is_same<Comp, thrust::greater<T>>::value,
|
||||
"Invalid comparator used in LowerBound - can only be thrust::greater/less");
|
||||
return std::numeric_limits<uint32_t>::max(); // Simply to quiesce the compiler
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ xgboost::common::Range GridStrideRange(T begin, T end) {
|
||||
begin += blockDim.x * blockIdx.x + threadIdx.x;
|
||||
@@ -878,7 +788,8 @@ class SegmentSorter {
|
||||
const uint32_t *dgroups = dgroups_.data().get();
|
||||
uint32_t ngroups = dgroups_.size();
|
||||
auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) {
|
||||
return dh::UpperBound(dgroups, ngroups, idx) - 1;
|
||||
return thrust::upper_bound(thrust::seq, dgroups, dgroups + ngroups, idx) -
|
||||
dgroups - 1;
|
||||
}; // NOLINT
|
||||
|
||||
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
|
||||
@@ -1018,70 +929,4 @@ thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
|
||||
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
||||
}
|
||||
|
||||
template <typename FunctionT>
|
||||
class LauncherItr {
|
||||
public:
|
||||
int idx { 0 };
|
||||
FunctionT f;
|
||||
XGBOOST_DEVICE LauncherItr() : idx(0) {} // NOLINT
|
||||
XGBOOST_DEVICE LauncherItr(int idx, FunctionT f) : idx(idx), f(f) {}
|
||||
XGBOOST_DEVICE LauncherItr &operator=(int output) {
|
||||
f(idx, output);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Thrust compatible iterator type - discards algorithm output and launches device lambda
|
||||
* with the index of the output and the algorithm output as arguments.
|
||||
*
|
||||
* \author Rory
|
||||
* \date 7/9/2017
|
||||
*
|
||||
* \tparam FunctionT Type of the function t.
|
||||
*/
|
||||
template <typename FunctionT>
|
||||
class DiscardLambdaItr {
|
||||
public:
|
||||
// Required iterator traits
|
||||
using self_type = DiscardLambdaItr; // NOLINT
|
||||
using difference_type = ptrdiff_t; // NOLINT
|
||||
using value_type = void; // NOLINT
|
||||
using pointer = value_type *; // NOLINT
|
||||
using reference = LauncherItr<FunctionT>; // NOLINT
|
||||
using iterator_category = typename thrust::detail::iterator_facade_category< // NOLINT
|
||||
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
|
||||
reference>::type; // NOLINT
|
||||
private:
|
||||
difference_type offset_;
|
||||
FunctionT f_;
|
||||
public:
|
||||
XGBOOST_DEVICE explicit DiscardLambdaItr(FunctionT f) : offset_(0), f_(f) {}
|
||||
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, FunctionT f)
|
||||
: offset_(offset), f_(f) {}
|
||||
XGBOOST_DEVICE self_type operator+(const int &b) const {
|
||||
return DiscardLambdaItr(offset_ + b, f_);
|
||||
}
|
||||
XGBOOST_DEVICE self_type operator++() {
|
||||
offset_++;
|
||||
return *this;
|
||||
}
|
||||
XGBOOST_DEVICE self_type operator++(int) {
|
||||
self_type retval = *this;
|
||||
offset_++;
|
||||
return retval;
|
||||
}
|
||||
XGBOOST_DEVICE self_type &operator+=(const int &b) {
|
||||
offset_ += b;
|
||||
return *this;
|
||||
}
|
||||
XGBOOST_DEVICE reference operator*() const {
|
||||
return LauncherItr<FunctionT>(offset_, f_);
|
||||
}
|
||||
XGBOOST_DEVICE reference operator[](int idx) {
|
||||
self_type offset = (*this) + idx;
|
||||
return *offset;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace dh
|
||||
|
||||
Reference in New Issue
Block a user