Fix DispatchRadixSort: disable the hipcub::DispatchRadixSort calls in ArgSort and call rocprim::radix_sort_pairs instead

This commit is contained in:
amdsc21 2023-03-08 20:31:23 +01:00
parent a45005863b
commit cd743a1ae9

View File

@ -1238,6 +1238,8 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
if (accending) {
void *d_temp_storage = nullptr;
#if 0
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
@ -1247,8 +1249,16 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
#endif
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
#if 0
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
@ -1258,8 +1268,15 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
#endif
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
} else {
void *d_temp_storage = nullptr;
#if 0
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
@ -1269,8 +1286,16 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
#endif
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
#if 0
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
@ -1280,6 +1305,10 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
#endif
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
}
safe_cuda(hipMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
@ -1355,7 +1384,7 @@ class LDGIterator {
__device__ T operator[](std::size_t idx) const {
DeviceWordT tmp[kNumWords];
static_assert(sizeof(tmp) == sizeof(T), "Expect sizes to be equal.");
#pragma unroll
for (int i = 0; i < kNumWords; i++) {
tmp[i] = __ldg(reinterpret_cast<const DeviceWordT *>(ptr_ + idx) + i);
}