Use thrust functions instead of custom functions (#5544)
This commit is contained in:
parent
6a169cd41a
commit
e268fb0093
@ -9,15 +9,14 @@
|
|||||||
#include <thrust/system_error.h>
|
#include <thrust/system_error.h>
|
||||||
#include <thrust/logical.h>
|
#include <thrust/logical.h>
|
||||||
#include <thrust/gather.h>
|
#include <thrust/gather.h>
|
||||||
|
#include <thrust/binary_search.h>
|
||||||
|
|
||||||
#include <omp.h>
|
|
||||||
#include <rabit/rabit.h>
|
#include <rabit/rabit.h>
|
||||||
#include <cub/cub.cuh>
|
#include <cub/cub.cuh>
|
||||||
#include <cub/util_allocator.cuh>
|
#include <cub/util_allocator.cuh>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <ctime>
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -28,7 +27,6 @@
|
|||||||
#include "xgboost/span.h"
|
#include "xgboost/span.h"
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "timer.h"
|
|
||||||
|
|
||||||
#ifdef XGBOOST_USE_NCCL
|
#ifdef XGBOOST_USE_NCCL
|
||||||
#include "nccl.h"
|
#include "nccl.h"
|
||||||
@ -132,94 +130,6 @@ DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, un
|
|||||||
static_cast<unsigned int>(b) << (ibyte % (sizeof(unsigned int)) * 8));
|
static_cast<unsigned int>(b) << (ibyte % (sizeof(unsigned int)) * 8));
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace internal {
|
|
||||||
|
|
||||||
// Items of size 'n' are sorted in an order determined by the Comparator
|
|
||||||
// If left is true, find the number of elements where 'comp(item, v)' returns true;
|
|
||||||
// 0 if nothing is true
|
|
||||||
// If left is false, find the number of elements where '!comp(item, v)' returns true;
|
|
||||||
// 0 if nothing is true
|
|
||||||
template <typename T, typename Comparator = thrust::greater<T>>
|
|
||||||
XGBOOST_DEVICE __forceinline__ uint32_t
|
|
||||||
CountNumItemsImpl(bool left, const T * __restrict__ items, uint32_t n, T v,
|
|
||||||
const Comparator &comp = Comparator()) {
|
|
||||||
const T *items_begin = items;
|
|
||||||
uint32_t num_remaining = n;
|
|
||||||
const T *middle_item = nullptr;
|
|
||||||
uint32_t middle;
|
|
||||||
while (num_remaining > 0) {
|
|
||||||
middle_item = items_begin;
|
|
||||||
middle = num_remaining / 2;
|
|
||||||
middle_item += middle;
|
|
||||||
if ((left && comp(*middle_item, v)) || (!left && !comp(v, *middle_item))) {
|
|
||||||
items_begin = ++middle_item;
|
|
||||||
num_remaining -= middle + 1;
|
|
||||||
} else {
|
|
||||||
num_remaining = middle;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return left ? items_begin - items : items + n - items_begin;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace internal
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Find the strict upper bound for an element in a sorted array
|
|
||||||
* using binary search.
|
|
||||||
* \param items pointer to the first element of the sorted array
|
|
||||||
* \param n length of the sorted array
|
|
||||||
* \param v value for which to find the upper bound
|
|
||||||
* \param comp determines how the items are sorted ascending/descending order - should conform
|
|
||||||
* to ordering semantics
|
|
||||||
* \return the smallest index i that has a value > v, or n if none is larger when sorted ascendingly
|
|
||||||
* or, an index i with a value < v, or 0 if none is smaller when sorted descendingly
|
|
||||||
*/
|
|
||||||
// Preserve existing default behavior of upper bound
|
|
||||||
template <typename T, typename Comp = thrust::less<T>>
|
|
||||||
XGBOOST_DEVICE __forceinline__ uint32_t UpperBound(const T *__restrict__ items,
|
|
||||||
uint32_t n,
|
|
||||||
T v,
|
|
||||||
const Comp &comp = Comp()) {
|
|
||||||
if (std::is_same<Comp, thrust::less<T>>::value ||
|
|
||||||
std::is_same<Comp, thrust::greater<T>>::value) {
|
|
||||||
return n - internal::CountNumItemsImpl(false, items, n, v, comp);
|
|
||||||
} else {
|
|
||||||
static_assert(std::is_same<Comp, thrust::less<T>>::value ||
|
|
||||||
std::is_same<Comp, thrust::greater<T>>::value,
|
|
||||||
"Invalid comparator used in Upperbound - can only be thrust::greater/less");
|
|
||||||
return std::numeric_limits<uint32_t>::max(); // Simply to quiesce the compiler
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Find the strict lower bound for an element in a sorted array
|
|
||||||
* using binary search.
|
|
||||||
* \param items pointer to the first element of the sorted array
|
|
||||||
* \param n length of the sorted array
|
|
||||||
* \param v value for which to find the upper bound
|
|
||||||
* \param comp determines how the items are sorted ascending/descending order - should conform
|
|
||||||
* to ordering semantics
|
|
||||||
* \return the smallest index i that has a value >= v, or n if none is larger
|
|
||||||
* when sorted ascendingly
|
|
||||||
* or, an index i with a value <= v, or 0 if none is smaller when sorted descendingly
|
|
||||||
*/
|
|
||||||
template <typename T, typename Comp = thrust::less<T>>
|
|
||||||
XGBOOST_DEVICE __forceinline__ uint32_t LowerBound(const T *__restrict__ items,
|
|
||||||
uint32_t n,
|
|
||||||
T v,
|
|
||||||
const Comp &comp = Comp()) {
|
|
||||||
if (std::is_same<Comp, thrust::less<T>>::value ||
|
|
||||||
std::is_same<Comp, thrust::greater<T>>::value) {
|
|
||||||
return internal::CountNumItemsImpl(true, items, n, v, comp);
|
|
||||||
} else {
|
|
||||||
static_assert(std::is_same<Comp, thrust::less<T>>::value ||
|
|
||||||
std::is_same<Comp, thrust::greater<T>>::value,
|
|
||||||
"Invalid comparator used in LowerBound - can only be thrust::greater/less");
|
|
||||||
return std::numeric_limits<uint32_t>::max(); // Simply to quiesce the compiler
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ xgboost::common::Range GridStrideRange(T begin, T end) {
|
__device__ xgboost::common::Range GridStrideRange(T begin, T end) {
|
||||||
begin += blockDim.x * blockIdx.x + threadIdx.x;
|
begin += blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
@ -878,7 +788,8 @@ class SegmentSorter {
|
|||||||
const uint32_t *dgroups = dgroups_.data().get();
|
const uint32_t *dgroups = dgroups_.data().get();
|
||||||
uint32_t ngroups = dgroups_.size();
|
uint32_t ngroups = dgroups_.size();
|
||||||
auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) {
|
auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) {
|
||||||
return dh::UpperBound(dgroups, ngroups, idx) - 1;
|
return thrust::upper_bound(thrust::seq, dgroups, dgroups + ngroups, idx) -
|
||||||
|
dgroups - 1;
|
||||||
}; // NOLINT
|
}; // NOLINT
|
||||||
|
|
||||||
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
|
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
|
||||||
@ -1018,70 +929,4 @@ thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
|
|||||||
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename FunctionT>
|
|
||||||
class LauncherItr {
|
|
||||||
public:
|
|
||||||
int idx { 0 };
|
|
||||||
FunctionT f;
|
|
||||||
XGBOOST_DEVICE LauncherItr() : idx(0) {} // NOLINT
|
|
||||||
XGBOOST_DEVICE LauncherItr(int idx, FunctionT f) : idx(idx), f(f) {}
|
|
||||||
XGBOOST_DEVICE LauncherItr &operator=(int output) {
|
|
||||||
f(idx, output);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* \brief Thrust compatible iterator type - discards algorithm output and launches device lambda
|
|
||||||
* with the index of the output and the algorithm output as arguments.
|
|
||||||
*
|
|
||||||
* \author Rory
|
|
||||||
* \date 7/9/2017
|
|
||||||
*
|
|
||||||
* \tparam FunctionT Type of the function t.
|
|
||||||
*/
|
|
||||||
template <typename FunctionT>
|
|
||||||
class DiscardLambdaItr {
|
|
||||||
public:
|
|
||||||
// Required iterator traits
|
|
||||||
using self_type = DiscardLambdaItr; // NOLINT
|
|
||||||
using difference_type = ptrdiff_t; // NOLINT
|
|
||||||
using value_type = void; // NOLINT
|
|
||||||
using pointer = value_type *; // NOLINT
|
|
||||||
using reference = LauncherItr<FunctionT>; // NOLINT
|
|
||||||
using iterator_category = typename thrust::detail::iterator_facade_category< // NOLINT
|
|
||||||
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
|
|
||||||
reference>::type; // NOLINT
|
|
||||||
private:
|
|
||||||
difference_type offset_;
|
|
||||||
FunctionT f_;
|
|
||||||
public:
|
|
||||||
XGBOOST_DEVICE explicit DiscardLambdaItr(FunctionT f) : offset_(0), f_(f) {}
|
|
||||||
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, FunctionT f)
|
|
||||||
: offset_(offset), f_(f) {}
|
|
||||||
XGBOOST_DEVICE self_type operator+(const int &b) const {
|
|
||||||
return DiscardLambdaItr(offset_ + b, f_);
|
|
||||||
}
|
|
||||||
XGBOOST_DEVICE self_type operator++() {
|
|
||||||
offset_++;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
XGBOOST_DEVICE self_type operator++(int) {
|
|
||||||
self_type retval = *this;
|
|
||||||
offset_++;
|
|
||||||
return retval;
|
|
||||||
}
|
|
||||||
XGBOOST_DEVICE self_type &operator+=(const int &b) {
|
|
||||||
offset_ += b;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
XGBOOST_DEVICE reference operator*() const {
|
|
||||||
return LauncherItr<FunctionT>(offset_, f_);
|
|
||||||
}
|
|
||||||
XGBOOST_DEVICE reference operator[](int idx) {
|
|
||||||
self_type offset = (*this) + idx;
|
|
||||||
return *offset;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace dh
|
} // namespace dh
|
||||||
|
|||||||
@ -44,7 +44,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
|
|||||||
size_t Size() const { return num_elements_; }
|
size_t Size() const { return num_elements_; }
|
||||||
__device__ COOTuple GetElement(size_t idx) const {
|
__device__ COOTuple GetElement(size_t idx) const {
|
||||||
size_t column_idx =
|
size_t column_idx =
|
||||||
dh::UpperBound(column_ptr_.data(), column_ptr_.size(), idx) - 1;
|
thrust::upper_bound(thrust::seq,column_ptr_.begin(), column_ptr_.end(), idx) - column_ptr_.begin() - 1;
|
||||||
auto& column = columns_[column_idx];
|
auto& column = columns_[column_idx];
|
||||||
size_t row_idx = idx - column_ptr_[column_idx];
|
size_t row_idx = idx - column_ptr_[column_idx];
|
||||||
float value = column.valid.Data() == nullptr || column.valid.Check(row_idx)
|
float value = column.valid.Data() == nullptr || column.valid.Check(row_idx)
|
||||||
|
|||||||
@ -49,7 +49,9 @@ __global__ void CompressBinEllpackKernel(
|
|||||||
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
|
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
|
||||||
// Assigning the bin in current entry.
|
// Assigning the bin in current entry.
|
||||||
// S.t.: fvalue < feature_cuts[bin]
|
// S.t.: fvalue < feature_cuts[bin]
|
||||||
bin = dh::UpperBound(feature_cuts, ncuts, fvalue);
|
bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts,
|
||||||
|
fvalue) -
|
||||||
|
feature_cuts;
|
||||||
if (bin >= ncuts) {
|
if (bin >= ncuts) {
|
||||||
bin = ncuts - 1;
|
bin = ncuts - 1;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -52,14 +52,18 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
XGBOOST_DEVICE __forceinline__ uint32_t
|
XGBOOST_DEVICE __forceinline__ uint32_t
|
||||||
CountNumItemsToTheLeftOf(const T * __restrict__ items, uint32_t n, T v) {
|
CountNumItemsToTheLeftOf(const T *__restrict__ items, uint32_t n, T v) {
|
||||||
return dh::LowerBound(items, n, v, thrust::greater<T>());
|
return thrust::lower_bound(thrust::seq, items, items + n, v,
|
||||||
|
thrust::greater<T>()) -
|
||||||
|
items;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
XGBOOST_DEVICE __forceinline__ uint32_t
|
XGBOOST_DEVICE __forceinline__ uint32_t
|
||||||
CountNumItemsToTheRightOf(const T * __restrict__ items, uint32_t n, T v) {
|
CountNumItemsToTheRightOf(const T *__restrict__ items, uint32_t n, T v) {
|
||||||
return n - dh::UpperBound(items, n, v, thrust::greater<T>());
|
return n - (thrust::upper_bound(thrust::seq, items, items + n, v,
|
||||||
|
thrust::greater<T>()) -
|
||||||
|
items);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -671,7 +675,10 @@ class SortedLabelList : dh::SegmentSorter<float> {
|
|||||||
dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) {
|
dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) {
|
||||||
// First, determine the group 'idx' belongs to
|
// First, determine the group 'idx' belongs to
|
||||||
uint32_t item_idx = idx % total_items;
|
uint32_t item_idx = idx % total_items;
|
||||||
uint32_t group_idx = dh::UpperBound(dgroups.data(), ngroups, item_idx);
|
uint32_t group_idx =
|
||||||
|
thrust::upper_bound(thrust::seq, dgroups.begin(),
|
||||||
|
dgroups.begin() + ngroups, item_idx) -
|
||||||
|
dgroups.begin();
|
||||||
// Span of this group within the larger labels/predictions sorted tuple
|
// Span of this group within the larger labels/predictions sorted tuple
|
||||||
uint32_t group_begin = dgroups[group_idx - 1];
|
uint32_t group_begin = dgroups[group_idx - 1];
|
||||||
uint32_t group_end = dgroups[group_idx];
|
uint32_t group_end = dgroups[group_idx];
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2017-2019 XGBoost contributors
|
* Copyright 2017-2019 XGBoost contributors
|
||||||
*/
|
*/
|
||||||
|
#include <thrust/iterator/discard_iterator.h>
|
||||||
|
#include <thrust/iterator/transform_output_iterator.h>
|
||||||
#include <thrust/sequence.h>
|
#include <thrust/sequence.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "../../common/device_helpers.cuh"
|
#include "../../common/device_helpers.cuh"
|
||||||
@ -11,58 +13,74 @@ namespace tree {
|
|||||||
|
|
||||||
struct IndicateLeftTransform {
|
struct IndicateLeftTransform {
|
||||||
bst_node_t left_nidx;
|
bst_node_t left_nidx;
|
||||||
explicit IndicateLeftTransform(bst_node_t left_nidx)
|
explicit IndicateLeftTransform(bst_node_t left_nidx) : left_nidx(left_nidx) {}
|
||||||
: left_nidx(left_nidx) {}
|
__host__ __device__ __forceinline__ size_t
|
||||||
__host__ __device__ __forceinline__ int operator()(const bst_node_t& x) const {
|
operator()(const bst_node_t& x) const {
|
||||||
return x == left_nidx ? 1 : 0;
|
return x == left_nidx ? 1 : 0;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
/*
|
|
||||||
* position: Position of rows belonged to current split node.
|
struct IndexFlagTuple {
|
||||||
*/
|
size_t idx;
|
||||||
|
size_t flag;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct IndexFlagOp {
|
||||||
|
__device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
|
||||||
|
const IndexFlagTuple& b) const {
|
||||||
|
return {b.idx, a.flag + b.flag};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct WriteResultsFunctor {
|
||||||
|
bst_node_t left_nidx;
|
||||||
|
common::Span<bst_node_t> position_in;
|
||||||
|
common::Span<bst_node_t> position_out;
|
||||||
|
common::Span<RowPartitioner::RowIndexT> ridx_in;
|
||||||
|
common::Span<RowPartitioner::RowIndexT> ridx_out;
|
||||||
|
int64_t* d_left_count;
|
||||||
|
|
||||||
|
__device__ int operator()(const IndexFlagTuple& x) {
|
||||||
|
// the ex_scan_result represents how many rows have been assigned to left
|
||||||
|
// node so far during scan.
|
||||||
|
int scatter_address;
|
||||||
|
if (position_in[x.idx] == left_nidx) {
|
||||||
|
scatter_address = x.flag - 1; // -1 because inclusive scan
|
||||||
|
} else {
|
||||||
|
// current number of rows belong to right node + total number of rows
|
||||||
|
// belong to left node
|
||||||
|
scatter_address = (x.idx - x.flag) + *d_left_count;
|
||||||
|
}
|
||||||
|
// copy the node id to output
|
||||||
|
position_out[scatter_address] = position_in[x.idx];
|
||||||
|
ridx_out[scatter_address] = ridx_in[x.idx];
|
||||||
|
|
||||||
|
// Discard
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
|
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
|
||||||
common::Span<bst_node_t> position_out,
|
common::Span<bst_node_t> position_out,
|
||||||
common::Span<RowIndexT> ridx,
|
common::Span<RowIndexT> ridx,
|
||||||
common::Span<RowIndexT> ridx_out,
|
common::Span<RowIndexT> ridx_out,
|
||||||
bst_node_t left_nidx,
|
bst_node_t left_nidx, bst_node_t right_nidx,
|
||||||
bst_node_t right_nidx,
|
|
||||||
int64_t* d_left_count, cudaStream_t stream) {
|
int64_t* d_left_count, cudaStream_t stream) {
|
||||||
// radix sort over 1 bit, see:
|
WriteResultsFunctor write_results{left_nidx, position, position_out,
|
||||||
// https://developer.nvidia.com/gpugems/GPUGems3/gpugems3_ch39.html
|
ridx, ridx_out, d_left_count};
|
||||||
auto d_position_out = position_out.data();
|
auto discard_write_iterator = thrust::make_transform_output_iterator(
|
||||||
auto d_position_in = position.data();
|
thrust::discard_iterator<int>(), write_results);
|
||||||
auto d_ridx_out = ridx_out.data();
|
auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
|
||||||
auto d_ridx_in = ridx.data();
|
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
|
||||||
auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
|
return IndexFlagTuple{idx, position[idx] == left_nidx};
|
||||||
// the ex_scan_result represents how many rows have been assigned to left node so far
|
});
|
||||||
// during scan.
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
int scatter_address;
|
thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator,
|
||||||
if (d_position_in[idx] == left_nidx) {
|
input_iterator + position.size(),
|
||||||
scatter_address = ex_scan_result;
|
discard_write_iterator,
|
||||||
} else {
|
[=] __device__(IndexFlagTuple a, IndexFlagTuple b) {
|
||||||
// current number of rows belong to right node + total number of rows belong to left
|
return IndexFlagTuple{b.idx, a.flag + b.flag};
|
||||||
// node
|
});
|
||||||
scatter_address = (idx - ex_scan_result) + *d_left_count;
|
|
||||||
}
|
|
||||||
// copy the node id to output
|
|
||||||
d_position_out[scatter_address] = d_position_in[idx];
|
|
||||||
d_ridx_out[scatter_address] = d_ridx_in[idx];
|
|
||||||
}; // NOLINT
|
|
||||||
|
|
||||||
IndicateLeftTransform is_left(left_nidx);
|
|
||||||
// an iterator that given a old position returns whether it belongs to left or right
|
|
||||||
// node.
|
|
||||||
cub::TransformInputIterator<bst_node_t, IndicateLeftTransform,
|
|
||||||
bst_node_t*>
|
|
||||||
in_itr(d_position_in, is_left);
|
|
||||||
dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
|
|
||||||
size_t temp_storage_bytes = 0;
|
|
||||||
// position is of the same size with current split node's row segment
|
|
||||||
cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
|
|
||||||
position.size(), stream);
|
|
||||||
dh::caching_device_vector<uint8_t> temp_storage(temp_storage_bytes);
|
|
||||||
cub::DeviceScan::ExclusiveSum(temp_storage.data().get(), temp_storage_bytes,
|
|
||||||
in_itr, out_itr, position.size(), stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
|
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
|
||||||
@ -137,7 +155,7 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment,
|
|||||||
SortPosition(
|
SortPosition(
|
||||||
// position_in
|
// position_in
|
||||||
common::Span<bst_node_t>(position_.Current() + segment.begin,
|
common::Span<bst_node_t>(position_.Current() + segment.begin,
|
||||||
segment.Size()),
|
segment.Size()),
|
||||||
// position_out
|
// position_out
|
||||||
common::Span<bst_node_t>(position_.Other() + segment.begin,
|
common::Span<bst_node_t>(position_.Other() + segment.begin,
|
||||||
segment.Size()),
|
segment.Size()),
|
||||||
|
|||||||
@ -8,25 +8,6 @@
|
|||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
using xgboost::common::Span;
|
|
||||||
|
|
||||||
void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
|
|
||||||
thrust::host_vector<int> *row_ptr,
|
|
||||||
thrust::host_vector<xgboost::bst_uint> *rows) {
|
|
||||||
row_ptr->resize(num_rows + 1);
|
|
||||||
int sum = 0;
|
|
||||||
for (xgboost::bst_uint i = 0; i <= num_rows; i++) {
|
|
||||||
(*row_ptr)[i] = sum;
|
|
||||||
sum += rand() % max_row_size; // NOLINT
|
|
||||||
|
|
||||||
if (i < num_rows) {
|
|
||||||
for (int j = (*row_ptr)[i]; j < sum; j++) {
|
|
||||||
(*rows).push_back(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(SumReduce, Test) {
|
TEST(SumReduce, Test) {
|
||||||
thrust::device_vector<float> data(100, 1.0f);
|
thrust::device_vector<float> data(100, 1.0f);
|
||||||
dh::CubMemory temp;
|
dh::CubMemory temp;
|
||||||
@ -34,80 +15,3 @@ TEST(SumReduce, Test) {
|
|||||||
ASSERT_NEAR(sum, 100.0f, 1e-5);
|
ASSERT_NEAR(sum, 100.0f, 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename Comp = thrust::less<T>>
|
|
||||||
void TestUpperBoundImpl(const std::vector<T> &vec, T val_to_find,
|
|
||||||
const Comp &comp = Comp()) {
|
|
||||||
EXPECT_EQ(dh::UpperBound(vec.data(), vec.size(), val_to_find, comp),
|
|
||||||
std::upper_bound(vec.begin(), vec.end(), val_to_find, comp) - vec.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename Comp = thrust::less<T>>
|
|
||||||
void TestLowerBoundImpl(const std::vector<T> &vec, T val_to_find,
|
|
||||||
const Comp &comp = Comp()) {
|
|
||||||
EXPECT_EQ(dh::LowerBound(vec.data(), vec.size(), val_to_find, comp),
|
|
||||||
std::lower_bound(vec.begin(), vec.end(), val_to_find, comp) - vec.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(UpperBound, DataAscending) {
|
|
||||||
std::vector<int> hvec{0, 3, 5, 5, 7, 8, 9, 10, 10};
|
|
||||||
|
|
||||||
// Test boundary conditions
|
|
||||||
TestUpperBoundImpl(hvec, hvec.front()); // Result 1
|
|
||||||
TestUpperBoundImpl(hvec, hvec.front() - 1); // Result 0
|
|
||||||
TestUpperBoundImpl(hvec, hvec.back() + 1); // Result hvec.size()
|
|
||||||
TestUpperBoundImpl(hvec, hvec.back()); // Result hvec.size()
|
|
||||||
|
|
||||||
// Test other values - both missing and present
|
|
||||||
TestUpperBoundImpl(hvec, 3); // Result 2
|
|
||||||
TestUpperBoundImpl(hvec, 4); // Result 2
|
|
||||||
TestUpperBoundImpl(hvec, 5); // Result 4
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(UpperBound, DataDescending) {
|
|
||||||
std::vector<int> hvec{10, 10, 9, 8, 7, 5, 5, 3, 0, 0};
|
|
||||||
const auto &comparator = thrust::greater<int>();
|
|
||||||
|
|
||||||
// Test boundary conditions
|
|
||||||
TestUpperBoundImpl(hvec, hvec.front(), comparator); // Result 2
|
|
||||||
TestUpperBoundImpl(hvec, hvec.front() + 1, comparator); // Result 0
|
|
||||||
TestUpperBoundImpl(hvec, hvec.back(), comparator); // Result hvec.size()
|
|
||||||
TestUpperBoundImpl(hvec, hvec.back() - 1, comparator); // Result hvec.size()
|
|
||||||
|
|
||||||
// Test other values - both missing and present
|
|
||||||
TestUpperBoundImpl(hvec, 9, comparator); // Result 3
|
|
||||||
TestUpperBoundImpl(hvec, 7, comparator); // Result 5
|
|
||||||
TestUpperBoundImpl(hvec, 4, comparator); // Result 7
|
|
||||||
TestUpperBoundImpl(hvec, 8, comparator); // Result 4
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(LowerBound, DataAscending) {
|
|
||||||
std::vector<int> hvec{0, 3, 5, 5, 7, 8, 9, 10, 10};
|
|
||||||
|
|
||||||
// Test boundary conditions
|
|
||||||
TestLowerBoundImpl(hvec, hvec.front()); // Result 0
|
|
||||||
TestLowerBoundImpl(hvec, hvec.front() - 1); // Result 0
|
|
||||||
TestLowerBoundImpl(hvec, hvec.back()); // Result 7
|
|
||||||
TestLowerBoundImpl(hvec, hvec.back() + 1); // Result hvec.size()
|
|
||||||
|
|
||||||
// Test other values - both missing and present
|
|
||||||
TestLowerBoundImpl(hvec, 3); // Result 1
|
|
||||||
TestLowerBoundImpl(hvec, 4); // Result 2
|
|
||||||
TestLowerBoundImpl(hvec, 5); // Result 2
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(LowerBound, DataDescending) {
|
|
||||||
std::vector<int> hvec{10, 10, 9, 8, 7, 5, 5, 3, 0, 0};
|
|
||||||
const auto &comparator = thrust::greater<int>();
|
|
||||||
|
|
||||||
// Test boundary conditions
|
|
||||||
TestLowerBoundImpl(hvec, hvec.front(), comparator); // Result 0
|
|
||||||
TestLowerBoundImpl(hvec, hvec.front() + 1, comparator); // Result 0
|
|
||||||
TestLowerBoundImpl(hvec, hvec.back(), comparator); // Result 8
|
|
||||||
TestLowerBoundImpl(hvec, hvec.back() - 1, comparator); // Result hvec.size()
|
|
||||||
|
|
||||||
// Test other values - both missing and present
|
|
||||||
TestLowerBoundImpl(hvec, 9, comparator); // Result 2
|
|
||||||
TestLowerBoundImpl(hvec, 7, comparator); // Result 4
|
|
||||||
TestLowerBoundImpl(hvec, 4, comparator); // Result 7
|
|
||||||
TestLowerBoundImpl(hvec, 8, comparator); // Result 3
|
|
||||||
}
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user