Fix deprecated CUB calls in CUDA 12.0 (#8578)
This commit is contained in:
parent
35d8447282
commit
15a88ceef0
@ -1172,17 +1172,32 @@ template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT,
|
||||
// Runs cub::DispatchScan over `num_items` elements read from `d_in`, writing to
// `d_out` and combining with `scan_op`; cub::NullType() is passed in the
// initial-value slot.
//
// The dispatch is issued twice: first with a null temporary buffer so that
// `bytes` is filled in, then again with a buffer of that size.
// Each dispatch result is handed to safe_cuda (presumably an error check on the
// returned status; its definition is not visible here, so confirm).
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
||||
OffsetT num_items) {
|
||||
size_t bytes = 0;
|
||||
// First dispatch: null temporary storage, `bytes` is passed so it can be set.
// With THRUST_MAJOR_VERSION >= 2 the trailing `false` argument is omitted.
// NOTE(review): that argument is presumably CUB's deprecated debug_synchronous
// flag (the commit title mentions deprecated CUB calls); confirm against CUB.
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((
|
||||
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
|
||||
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
|
||||
cub::NullType(), num_items, nullptr)));
|
||||
#else
|
||||
safe_cuda((
|
||||
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
|
||||
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
|
||||
cub::NullType(), num_items, nullptr,
|
||||
false)));
|
||||
#endif
|
||||
// Temporary buffer sized from the value written into `bytes` above.
TemporaryArray<char> storage(bytes);
|
||||
// Second dispatch: same arguments, now with the allocated temporary buffer.
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((
|
||||
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
|
||||
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
|
||||
d_out, scan_op, cub::NullType(),
|
||||
num_items, nullptr)));
|
||||
#else
|
||||
safe_cuda((
|
||||
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
|
||||
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
|
||||
d_out, scan_op, cub::NullType(),
|
||||
num_items, nullptr, false)));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename InIt, typename OutIt, typename Predicate>
|
||||
@ -1225,24 +1240,48 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
|
||||
if (accending) {
|
||||
void *d_temp_storage = nullptr;
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
TemporaryArray<char> storage(bytes);
|
||||
d_temp_storage = storage.data().get();
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
} else {
|
||||
void *d_temp_storage = nullptr;
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
TemporaryArray<char> storage(bytes);
|
||||
d_temp_storage = storage.data().get();
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
}
|
||||
|
||||
safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
|
||||
@ -1269,7 +1308,14 @@ void DeviceSegmentedRadixSortPair(
|
||||
CHECK_LE(num_items, std::numeric_limits<OffsetT>::max());
|
||||
// For Thrust >= 1.12 or CUDA >= 11.4, we require system cub installation
|
||||
|
||||
#if (THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION >= 13) || THRUST_MAJOR_VERSION > 1
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((cub::DispatchSegmentedRadixSort<
|
||||
descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
|
||||
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
|
||||
d_values, num_items, num_segments,
|
||||
d_begin_offsets, d_end_offsets, begin_bit,
|
||||
end_bit, false, nullptr)));
|
||||
#elif (THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION >= 13)
|
||||
safe_cuda((cub::DispatchSegmentedRadixSort<
|
||||
descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
|
||||
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
|
||||
|
||||
@ -238,13 +238,25 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
using DispatchScan =
|
||||
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
||||
TupleScanOp<Tuple>, cub::NullType, int64_t>;
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
nullptr);
|
||||
#else
|
||||
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
nullptr, false);
|
||||
#endif
|
||||
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
cub::NullType(), batch.Size(), nullptr);
|
||||
#else
|
||||
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
cub::NullType(), batch.Size(), nullptr, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user