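Fix DispatchRadixSort: disable the hipcub::DispatchRadixSort calls in ArgSort (#if 0) and call rocprim::radix_sort_pairs instead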
This commit is contained in:
parent
a45005863b
commit
cd743a1ae9
@ -1238,6 +1238,8 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
|||||||
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
|
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
|
||||||
if (accending) {
|
if (accending) {
|
||||||
void *d_temp_storage = nullptr;
|
void *d_temp_storage = nullptr;
|
||||||
|
|
||||||
|
#if 0
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
@ -1247,8 +1249,16 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
|||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||||
|
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
|
sizeof(KeyT) * 8)));
|
||||||
|
|
||||||
TemporaryArray<char> storage(bytes);
|
TemporaryArray<char> storage(bytes);
|
||||||
d_temp_storage = storage.data().get();
|
d_temp_storage = storage.data().get();
|
||||||
|
|
||||||
|
#if 0
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
@ -1258,8 +1268,15 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
|||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||||
|
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
|
sizeof(KeyT) * 8)));
|
||||||
} else {
|
} else {
|
||||||
void *d_temp_storage = nullptr;
|
void *d_temp_storage = nullptr;
|
||||||
|
|
||||||
|
#if 0
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
@ -1269,8 +1286,16 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
|||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||||
|
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
|
sizeof(KeyT) * 8)));
|
||||||
|
|
||||||
TemporaryArray<char> storage(bytes);
|
TemporaryArray<char> storage(bytes);
|
||||||
d_temp_storage = storage.data().get();
|
d_temp_storage = storage.data().get();
|
||||||
|
|
||||||
|
#if 0
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
@ -1280,6 +1305,10 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
|||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||||
|
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
|
sizeof(KeyT) * 8)));
|
||||||
}
|
}
|
||||||
|
|
||||||
safe_cuda(hipMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
|
safe_cuda(hipMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
|
||||||
@ -1355,7 +1384,7 @@ class LDGIterator {
|
|||||||
__device__ T operator[](std::size_t idx) const {
|
__device__ T operator[](std::size_t idx) const {
|
||||||
DeviceWordT tmp[kNumWords];
|
DeviceWordT tmp[kNumWords];
|
||||||
static_assert(sizeof(tmp) == sizeof(T), "Expect sizes to be equal.");
|
static_assert(sizeof(tmp) == sizeof(T), "Expect sizes to be equal.");
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kNumWords; i++) {
|
for (int i = 0; i < kNumWords; i++) {
|
||||||
tmp[i] = __ldg(reinterpret_cast<const DeviceWordT *>(ptr_ + idx) + i);
|
tmp[i] = __ldg(reinterpret_cast<const DeviceWordT *>(ptr_ + idx) + i);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user