Add device argsort. (#6749)
This is part of https://github.com/dmlc/xgboost/pull/6747 .
This commit is contained in:
parent
325bc93e16
commit
1a73a28511
@ -95,7 +95,7 @@ namespace common {
|
||||
#define KERNEL_CHECK(cond) \
|
||||
(XGBOOST_EXPECT((cond), true) \
|
||||
? static_cast<void>(0) \
|
||||
: __assert_fail(__ASSERT_STR_HELPER(e), __FILE__, __LINE__, \
|
||||
: __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, \
|
||||
__PRETTY_FUNCTION__))
|
||||
|
||||
#endif // defined(_MSC_VER)
|
||||
|
||||
@ -295,6 +295,11 @@ inline void LaunchN(int device_idx, size_t n, L lambda) {
|
||||
LaunchN<ITEMS_PER_THREAD, BLOCK_THREADS>(device_idx, n, nullptr, lambda);
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
void Iota(Container array, int32_t device = CurrentDevice()) {
|
||||
LaunchN(device, array.size(), [=] __device__(size_t i) { array[i] = i; });
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
/** \brief Keeps track of global device memory allocations. Thread safe.*/
|
||||
class MemoryLogger {
|
||||
@ -1179,4 +1184,72 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce
|
||||
}
|
||||
return aggregate;
|
||||
}
|
||||
|
||||
template <bool accending, typename IdxT, typename U>
|
||||
void ArgSort(xgboost::common::Span<U> values, xgboost::common::Span<IdxT> sorted_idx) {
|
||||
size_t bytes = 0;
|
||||
Iota(sorted_idx);
|
||||
CHECK_LT(sorted_idx.size(), 1 << 31);
|
||||
TemporaryArray<U> out(values.size());
|
||||
if (accending) {
|
||||
cub::DeviceRadixSort::SortPairs(nullptr, bytes, values.data(),
|
||||
out.data().get(), sorted_idx.data(),
|
||||
sorted_idx.data(), sorted_idx.size());
|
||||
dh::TemporaryArray<char> storage(bytes);
|
||||
cub::DeviceRadixSort::SortPairs(storage.data().get(), bytes, values.data(),
|
||||
out.data().get(), sorted_idx.data(),
|
||||
sorted_idx.data(), sorted_idx.size());
|
||||
} else {
|
||||
cub::DeviceRadixSort::SortPairsDescending(
|
||||
nullptr, bytes, values.data(), out.data().get(), sorted_idx.data(),
|
||||
sorted_idx.data(), sorted_idx.size());
|
||||
dh::TemporaryArray<char> storage(bytes);
|
||||
cub::DeviceRadixSort::SortPairsDescending(
|
||||
storage.data().get(), bytes, values.data(), out.data().get(),
|
||||
sorted_idx.data(), sorted_idx.data(), sorted_idx.size());
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
// Wrapper around cub sort for easier `descending` sort
|
||||
template <bool descending, typename KeyT, typename ValueT, typename OffsetIteratorT>
|
||||
void DeviceSegmentedRadixSortPair(
|
||||
void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, // NOLINT
|
||||
KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out,
|
||||
size_t num_items, size_t num_segments, OffsetIteratorT d_begin_offsets,
|
||||
OffsetIteratorT d_end_offsets, int begin_bit = 0,
|
||||
int end_bit = sizeof(KeyT) * 8) {
|
||||
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
|
||||
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(d_values_in),
|
||||
d_values_out);
|
||||
using OffsetT = size_t;
|
||||
dh::safe_cuda((cub::DispatchSegmentedRadixSort<
|
||||
descending, KeyT, ValueT, OffsetIteratorT,
|
||||
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
|
||||
d_values, num_items, num_segments,
|
||||
d_begin_offsets, d_end_offsets, begin_bit,
|
||||
end_bit, false, nullptr, false)));
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
template <bool accending, typename U, typename V, typename IdxT>
|
||||
void SegmentedArgSort(xgboost::common::Span<U> values,
|
||||
xgboost::common::Span<V> group_ptr,
|
||||
xgboost::common::Span<IdxT> sorted_idx) {
|
||||
CHECK_GE(group_ptr.size(), 1ul);
|
||||
size_t n_groups = group_ptr.size() - 1;
|
||||
size_t bytes = 0;
|
||||
Iota(sorted_idx);
|
||||
CHECK_LT(sorted_idx.size(), 1 << 31);
|
||||
TemporaryArray<U> values_out(values.size());
|
||||
detail::DeviceSegmentedRadixSortPair<!accending>(
|
||||
nullptr, bytes, values.data(), values_out.data().get(),
|
||||
sorted_idx.data(), sorted_idx.data(), sorted_idx.size(), n_groups,
|
||||
group_ptr.data(), group_ptr.data() + 1);
|
||||
dh::TemporaryArray<xgboost::common::byte> temp_storage(bytes);
|
||||
detail::DeviceSegmentedRadixSortPair<!accending>(
|
||||
temp_storage.data().get(), bytes, values.data(), values_out.data().get(),
|
||||
sorted_idx.data(), sorted_idx.data(), sorted_idx.size(), n_groups,
|
||||
group_ptr.data(), group_ptr.data() + 1);
|
||||
}
|
||||
} // namespace dh
|
||||
|
||||
@ -171,5 +171,28 @@ TEST(Allocator, OOM) {
|
||||
// Clear last error so we don't fail subsequent tests
|
||||
cudaGetLastError();
|
||||
}
|
||||
|
||||
TEST(DeviceHelpers, ArgSort) {
|
||||
dh::device_vector<float> values(20);
|
||||
dh::Iota(dh::ToSpan(values)); // accending
|
||||
dh::device_vector<size_t> sorted_idx(20);
|
||||
dh::ArgSort<false>(dh::ToSpan(values), dh::ToSpan(sorted_idx)); // sort to descending
|
||||
ASSERT_TRUE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
|
||||
sorted_idx.end(), thrust::greater<size_t>{}));
|
||||
|
||||
dh::Iota(dh::ToSpan(values));
|
||||
dh::device_vector<size_t> groups(3);
|
||||
groups[0] = 0;
|
||||
groups[1] = 10;
|
||||
groups[2] = 20;
|
||||
dh::SegmentedArgSort<false>(dh::ToSpan(values), dh::ToSpan(groups),
|
||||
dh::ToSpan(sorted_idx));
|
||||
ASSERT_FALSE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
|
||||
sorted_idx.end(), thrust::greater<size_t>{}));
|
||||
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin(), sorted_idx.begin() + 10,
|
||||
thrust::greater<size_t>{}));
|
||||
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(),
|
||||
thrust::greater<size_t>{}));
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user