Fix DispatchRadixSort: disable the hipcub dispatch path in ArgSort and call rocprim::radix_sort_pairs instead
This commit is contained in:
parent
a45005863b
commit
cd743a1ae9
@ -1238,6 +1238,8 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
|
||||
if (accending) {
|
||||
void *d_temp_storage = nullptr;
|
||||
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
@ -1247,8 +1249,16 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
|
||||
TemporaryArray<char> storage(bytes);
|
||||
d_temp_storage = storage.data().get();
|
||||
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
@ -1258,8 +1268,15 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
} else {
|
||||
void *d_temp_storage = nullptr;
|
||||
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
@ -1269,8 +1286,16 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
|
||||
TemporaryArray<char> storage(bytes);
|
||||
d_temp_storage = storage.data().get();
|
||||
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
@ -1280,6 +1305,10 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
#endif
|
||||
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
}
|
||||
|
||||
safe_cuda(hipMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
|
||||
@ -1355,7 +1384,7 @@ class LDGIterator {
|
||||
__device__ T operator[](std::size_t idx) const {
|
||||
DeviceWordT tmp[kNumWords];
|
||||
static_assert(sizeof(tmp) == sizeof(T), "Expect sizes to be equal.");
|
||||
#pragma unroll
|
||||
|
||||
for (int i = 0; i < kNumWords; i++) {
|
||||
tmp[i] = __ldg(reinterpret_cast<const DeviceWordT *>(ptr_ + idx) + i);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user