From cd743a1ae9b80a1d37518eaa0ac3a7bef0b8b8fd Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 20:31:23 +0100 Subject: [PATCH] fix DispatchRadixSort --- src/common/device_helpers.hip.h | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/src/common/device_helpers.hip.h b/src/common/device_helpers.hip.h index 3ac3f6b6a..2044f985a 100644 --- a/src/common/device_helpers.hip.h +++ b/src/common/device_helpers.hip.h @@ -1238,6 +1238,8 @@ void ArgSort(xgboost::common::Span keys, xgboost::common::Span sorted_i CHECK_LE(sorted_idx.size(), std::numeric_limits::max()); if (accending) { void *d_temp_storage = nullptr; + +#if 0 #if THRUST_MAJOR_VERSION >= 2 safe_cuda((hipcub::DispatchRadixSort::Dispatch( d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, @@ -1247,8 +1249,16 @@ void ArgSort(xgboost::common::Span keys, xgboost::common::Span sorted_i d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false, nullptr, false))); #endif +#endif + + safe_cuda((rocprim::radix_sort_pairs(d_temp_storage, + bytes, d_keys, d_values, sorted_idx.size(), 0, + sizeof(KeyT) * 8))); + TemporaryArray storage(bytes); d_temp_storage = storage.data().get(); + +#if 0 #if THRUST_MAJOR_VERSION >= 2 safe_cuda((hipcub::DispatchRadixSort::Dispatch( d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, @@ -1258,8 +1268,15 @@ void ArgSort(xgboost::common::Span keys, xgboost::common::Span sorted_i d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false, nullptr, false))); #endif +#endif + + safe_cuda((rocprim::radix_sort_pairs(d_temp_storage, + bytes, d_keys, d_values, sorted_idx.size(), 0, + sizeof(KeyT) * 8))); } else { void *d_temp_storage = nullptr; + +#if 0 #if THRUST_MAJOR_VERSION >= 2 safe_cuda((hipcub::DispatchRadixSort::Dispatch( d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, @@ -1269,8 +1286,16 @@ void ArgSort(xgboost::common::Span keys, xgboost::common::Span sorted_i d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false, nullptr, false))); #endif +#endif + + safe_cuda((rocprim::radix_sort_pairs(d_temp_storage, + bytes, d_keys, d_values, sorted_idx.size(), 0, + sizeof(KeyT) * 8))); + TemporaryArray storage(bytes); d_temp_storage = storage.data().get(); + +#if 0 #if THRUST_MAJOR_VERSION >= 2 safe_cuda((hipcub::DispatchRadixSort::Dispatch( d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, @@ -1280,6 +1305,10 @@ void ArgSort(xgboost::common::Span keys, xgboost::common::Span sorted_i d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false, nullptr, false))); #endif +#endif + safe_cuda((rocprim::radix_sort_pairs(d_temp_storage, + bytes, d_keys, d_values, sorted_idx.size(), 0, + sizeof(KeyT) * 8))); } safe_cuda(hipMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(), @@ -1355,7 +1384,7 @@ class LDGIterator { __device__ T operator[](std::size_t idx) const { DeviceWordT tmp[kNumWords]; static_assert(sizeof(tmp) == sizeof(T), "Expect sizes to be equal."); -#pragma unroll + for (int i = 0; i < kNumWords; i++) { tmp[i] = __ldg(reinterpret_cast(ptr_ + idx) + i); }