Add device argsort. (#6749)
This is part of https://github.com/dmlc/xgboost/pull/6747 .
This commit is contained in:
parent
325bc93e16
commit
1a73a28511
@ -95,7 +95,7 @@ namespace common {
|
|||||||
#define KERNEL_CHECK(cond) \
|
#define KERNEL_CHECK(cond) \
|
||||||
(XGBOOST_EXPECT((cond), true) \
|
(XGBOOST_EXPECT((cond), true) \
|
||||||
? static_cast<void>(0) \
|
? static_cast<void>(0) \
|
||||||
: __assert_fail(__ASSERT_STR_HELPER(e), __FILE__, __LINE__, \
|
: __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, \
|
||||||
__PRETTY_FUNCTION__))
|
__PRETTY_FUNCTION__))
|
||||||
|
|
||||||
#endif // defined(_MSC_VER)
|
#endif // defined(_MSC_VER)
|
||||||
|
|||||||
@ -295,6 +295,11 @@ inline void LaunchN(int device_idx, size_t n, L lambda) {
|
|||||||
LaunchN<ITEMS_PER_THREAD, BLOCK_THREADS>(device_idx, n, nullptr, lambda);
|
LaunchN<ITEMS_PER_THREAD, BLOCK_THREADS>(device_idx, n, nullptr, lambda);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Container>
|
||||||
|
void Iota(Container array, int32_t device = CurrentDevice()) {
|
||||||
|
LaunchN(device, array.size(), [=] __device__(size_t i) { array[i] = i; });
|
||||||
|
}
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
/** \brief Keeps track of global device memory allocations. Thread safe.*/
|
/** \brief Keeps track of global device memory allocations. Thread safe.*/
|
||||||
class MemoryLogger {
|
class MemoryLogger {
|
||||||
@ -1179,4 +1184,72 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce
|
|||||||
}
|
}
|
||||||
return aggregate;
|
return aggregate;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <bool accending, typename IdxT, typename U>
|
||||||
|
void ArgSort(xgboost::common::Span<U> values, xgboost::common::Span<IdxT> sorted_idx) {
|
||||||
|
size_t bytes = 0;
|
||||||
|
Iota(sorted_idx);
|
||||||
|
CHECK_LT(sorted_idx.size(), 1 << 31);
|
||||||
|
TemporaryArray<U> out(values.size());
|
||||||
|
if (accending) {
|
||||||
|
cub::DeviceRadixSort::SortPairs(nullptr, bytes, values.data(),
|
||||||
|
out.data().get(), sorted_idx.data(),
|
||||||
|
sorted_idx.data(), sorted_idx.size());
|
||||||
|
dh::TemporaryArray<char> storage(bytes);
|
||||||
|
cub::DeviceRadixSort::SortPairs(storage.data().get(), bytes, values.data(),
|
||||||
|
out.data().get(), sorted_idx.data(),
|
||||||
|
sorted_idx.data(), sorted_idx.size());
|
||||||
|
} else {
|
||||||
|
cub::DeviceRadixSort::SortPairsDescending(
|
||||||
|
nullptr, bytes, values.data(), out.data().get(), sorted_idx.data(),
|
||||||
|
sorted_idx.data(), sorted_idx.size());
|
||||||
|
dh::TemporaryArray<char> storage(bytes);
|
||||||
|
cub::DeviceRadixSort::SortPairsDescending(
|
||||||
|
storage.data().get(), bytes, values.data(), out.data().get(),
|
||||||
|
sorted_idx.data(), sorted_idx.data(), sorted_idx.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
// Wrapper around cub sort for easier `descending` sort
|
||||||
|
template <bool descending, typename KeyT, typename ValueT, typename OffsetIteratorT>
|
||||||
|
void DeviceSegmentedRadixSortPair(
|
||||||
|
void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, // NOLINT
|
||||||
|
KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out,
|
||||||
|
size_t num_items, size_t num_segments, OffsetIteratorT d_begin_offsets,
|
||||||
|
OffsetIteratorT d_end_offsets, int begin_bit = 0,
|
||||||
|
int end_bit = sizeof(KeyT) * 8) {
|
||||||
|
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
|
||||||
|
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(d_values_in),
|
||||||
|
d_values_out);
|
||||||
|
using OffsetT = size_t;
|
||||||
|
dh::safe_cuda((cub::DispatchSegmentedRadixSort<
|
||||||
|
descending, KeyT, ValueT, OffsetIteratorT,
|
||||||
|
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
|
||||||
|
d_values, num_items, num_segments,
|
||||||
|
d_begin_offsets, d_end_offsets, begin_bit,
|
||||||
|
end_bit, false, nullptr, false)));
|
||||||
|
}
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <bool accending, typename U, typename V, typename IdxT>
|
||||||
|
void SegmentedArgSort(xgboost::common::Span<U> values,
|
||||||
|
xgboost::common::Span<V> group_ptr,
|
||||||
|
xgboost::common::Span<IdxT> sorted_idx) {
|
||||||
|
CHECK_GE(group_ptr.size(), 1ul);
|
||||||
|
size_t n_groups = group_ptr.size() - 1;
|
||||||
|
size_t bytes = 0;
|
||||||
|
Iota(sorted_idx);
|
||||||
|
CHECK_LT(sorted_idx.size(), 1 << 31);
|
||||||
|
TemporaryArray<U> values_out(values.size());
|
||||||
|
detail::DeviceSegmentedRadixSortPair<!accending>(
|
||||||
|
nullptr, bytes, values.data(), values_out.data().get(),
|
||||||
|
sorted_idx.data(), sorted_idx.data(), sorted_idx.size(), n_groups,
|
||||||
|
group_ptr.data(), group_ptr.data() + 1);
|
||||||
|
dh::TemporaryArray<xgboost::common::byte> temp_storage(bytes);
|
||||||
|
detail::DeviceSegmentedRadixSortPair<!accending>(
|
||||||
|
temp_storage.data().get(), bytes, values.data(), values_out.data().get(),
|
||||||
|
sorted_idx.data(), sorted_idx.data(), sorted_idx.size(), n_groups,
|
||||||
|
group_ptr.data(), group_ptr.data() + 1);
|
||||||
|
}
|
||||||
} // namespace dh
|
} // namespace dh
|
||||||
|
|||||||
@ -171,5 +171,28 @@ TEST(Allocator, OOM) {
|
|||||||
// Clear last error so we don't fail subsequent tests
|
// Clear last error so we don't fail subsequent tests
|
||||||
cudaGetLastError();
|
cudaGetLastError();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(DeviceHelpers, ArgSort) {
|
||||||
|
dh::device_vector<float> values(20);
|
||||||
|
dh::Iota(dh::ToSpan(values)); // accending
|
||||||
|
dh::device_vector<size_t> sorted_idx(20);
|
||||||
|
dh::ArgSort<false>(dh::ToSpan(values), dh::ToSpan(sorted_idx)); // sort to descending
|
||||||
|
ASSERT_TRUE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
|
||||||
|
sorted_idx.end(), thrust::greater<size_t>{}));
|
||||||
|
|
||||||
|
dh::Iota(dh::ToSpan(values));
|
||||||
|
dh::device_vector<size_t> groups(3);
|
||||||
|
groups[0] = 0;
|
||||||
|
groups[1] = 10;
|
||||||
|
groups[2] = 20;
|
||||||
|
dh::SegmentedArgSort<false>(dh::ToSpan(values), dh::ToSpan(groups),
|
||||||
|
dh::ToSpan(sorted_idx));
|
||||||
|
ASSERT_FALSE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
|
||||||
|
sorted_idx.end(), thrust::greater<size_t>{}));
|
||||||
|
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin(), sorted_idx.begin() + 10,
|
||||||
|
thrust::greater<size_t>{}));
|
||||||
|
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(),
|
||||||
|
thrust::greater<size_t>{}));
|
||||||
|
}
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user