Fix deprecated CUB calls in CUDA 12.0 (#8578)

This commit is contained in:
Rong Ou 2022-12-12 01:02:30 -08:00 committed by GitHub
parent 35d8447282
commit 15a88ceef0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 1 deletions

View File

@ -1172,17 +1172,32 @@ template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT,
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op, void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
OffsetT num_items) { OffsetT num_items) {
size_t bytes = 0; size_t bytes = 0;
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
cub::NullType(), num_items, nullptr)));
#else
safe_cuda(( safe_cuda((
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType, cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op, OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
cub::NullType(), num_items, nullptr, cub::NullType(), num_items, nullptr,
false))); false)));
#endif
TemporaryArray<char> storage(bytes); TemporaryArray<char> storage(bytes);
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
d_out, scan_op, cub::NullType(),
num_items, nullptr)));
#else
safe_cuda(( safe_cuda((
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType, cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
OffsetT>::Dispatch(storage.data().get(), bytes, d_in, OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
d_out, scan_op, cub::NullType(), d_out, scan_op, cub::NullType(),
num_items, nullptr, false))); num_items, nullptr, false)));
#endif
} }
template <typename InIt, typename OutIt, typename Predicate> template <typename InIt, typename OutIt, typename Predicate>
@ -1225,24 +1240,48 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max()); CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
if (accending) { if (accending) {
void *d_temp_storage = nullptr; void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch( safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false))); sizeof(KeyT) * 8, false, nullptr, false)));
#endif
TemporaryArray<char> storage(bytes); TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get(); d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch( safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false))); sizeof(KeyT) * 8, false, nullptr, false)));
#endif
} else { } else {
void *d_temp_storage = nullptr; void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch( safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false))); sizeof(KeyT) * 8, false, nullptr, false)));
#endif
TemporaryArray<char> storage(bytes); TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get(); d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch( safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false))); sizeof(KeyT) * 8, false, nullptr, false)));
#endif
} }
safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(), safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
@ -1269,7 +1308,14 @@ void DeviceSegmentedRadixSortPair(
CHECK_LE(num_items, std::numeric_limits<OffsetT>::max()); CHECK_LE(num_items, std::numeric_limits<OffsetT>::max());
// For Thrust >= 1.12 or CUDA >= 11.4, we require system cub installation // For Thrust >= 1.12 or CUDA >= 11.4, we require system cub installation
#if (THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION >= 13) || THRUST_MAJOR_VERSION > 1 #if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchSegmentedRadixSort<
descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
d_values, num_items, num_segments,
d_begin_offsets, d_end_offsets, begin_bit,
end_bit, false, nullptr)));
#elif (THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION >= 13)
safe_cuda((cub::DispatchSegmentedRadixSort< safe_cuda((cub::DispatchSegmentedRadixSort<
descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT, descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,

View File

@ -238,13 +238,25 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
using DispatchScan = using DispatchScan =
cub::DispatchScan<decltype(key_value_index_iter), decltype(out), cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
TupleScanOp<Tuple>, cub::NullType, int64_t>; TupleScanOp<Tuple>, cub::NullType, int64_t>;
#if THRUST_MAJOR_VERSION >= 2
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
nullptr);
#else
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out, DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(), TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
nullptr, false); nullptr, false);
#endif
dh::TemporaryArray<char> temp_storage(temp_storage_bytes); dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
#if THRUST_MAJOR_VERSION >= 2
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(),
cub::NullType(), batch.Size(), nullptr);
#else
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes, DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(), key_value_index_iter, out, TupleScanOp<Tuple>(),
cub::NullType(), batch.Size(), nullptr, false); cub::NullType(), batch.Size(), nullptr, false);
#endif
} }
void WriteNullValues(EllpackPageImpl* dst, int device_idx, void WriteNullValues(EllpackPageImpl* dst, int device_idx,