Fix deprecated CUB calls in CUDA 12.0 (#8578)
This commit is contained in:
parent
35d8447282
commit
15a88ceef0
@ -1172,17 +1172,32 @@ template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT,
|
|||||||
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
||||||
OffsetT num_items) {
|
OffsetT num_items) {
|
||||||
size_t bytes = 0;
|
size_t bytes = 0;
|
||||||
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
|
safe_cuda((
|
||||||
|
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
|
||||||
|
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
|
||||||
|
cub::NullType(), num_items, nullptr)));
|
||||||
|
#else
|
||||||
safe_cuda((
|
safe_cuda((
|
||||||
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
|
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
|
||||||
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
|
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
|
||||||
cub::NullType(), num_items, nullptr,
|
cub::NullType(), num_items, nullptr,
|
||||||
false)));
|
false)));
|
||||||
|
#endif
|
||||||
TemporaryArray<char> storage(bytes);
|
TemporaryArray<char> storage(bytes);
|
||||||
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
|
safe_cuda((
|
||||||
|
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
|
||||||
|
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
|
||||||
|
d_out, scan_op, cub::NullType(),
|
||||||
|
num_items, nullptr)));
|
||||||
|
#else
|
||||||
safe_cuda((
|
safe_cuda((
|
||||||
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
|
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
|
||||||
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
|
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
|
||||||
d_out, scan_op, cub::NullType(),
|
d_out, scan_op, cub::NullType(),
|
||||||
num_items, nullptr, false)));
|
num_items, nullptr, false)));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename InIt, typename OutIt, typename Predicate>
|
template <typename InIt, typename OutIt, typename Predicate>
|
||||||
@ -1225,24 +1240,48 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
|||||||
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
|
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
|
||||||
if (accending) {
|
if (accending) {
|
||||||
void *d_temp_storage = nullptr;
|
void *d_temp_storage = nullptr;
|
||||||
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
|
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
|
sizeof(KeyT) * 8, false, nullptr)));
|
||||||
|
#else
|
||||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||||
|
#endif
|
||||||
TemporaryArray<char> storage(bytes);
|
TemporaryArray<char> storage(bytes);
|
||||||
d_temp_storage = storage.data().get();
|
d_temp_storage = storage.data().get();
|
||||||
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
|
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
|
sizeof(KeyT) * 8, false, nullptr)));
|
||||||
|
#else
|
||||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||||
|
#endif
|
||||||
} else {
|
} else {
|
||||||
void *d_temp_storage = nullptr;
|
void *d_temp_storage = nullptr;
|
||||||
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
|
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
|
sizeof(KeyT) * 8, false, nullptr)));
|
||||||
|
#else
|
||||||
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||||
|
#endif
|
||||||
TemporaryArray<char> storage(bytes);
|
TemporaryArray<char> storage(bytes);
|
||||||
d_temp_storage = storage.data().get();
|
d_temp_storage = storage.data().get();
|
||||||
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
|
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
|
sizeof(KeyT) * 8, false, nullptr)));
|
||||||
|
#else
|
||||||
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
|
safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
|
||||||
@ -1269,7 +1308,14 @@ void DeviceSegmentedRadixSortPair(
|
|||||||
CHECK_LE(num_items, std::numeric_limits<OffsetT>::max());
|
CHECK_LE(num_items, std::numeric_limits<OffsetT>::max());
|
||||||
// For Thrust >= 1.12 or CUDA >= 11.4, we require system cub installation
|
// For Thrust >= 1.12 or CUDA >= 11.4, we require system cub installation
|
||||||
|
|
||||||
#if (THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION >= 13) || THRUST_MAJOR_VERSION > 1
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
|
safe_cuda((cub::DispatchSegmentedRadixSort<
|
||||||
|
descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
|
||||||
|
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
|
||||||
|
d_values, num_items, num_segments,
|
||||||
|
d_begin_offsets, d_end_offsets, begin_bit,
|
||||||
|
end_bit, false, nullptr)));
|
||||||
|
#elif (THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION >= 13)
|
||||||
safe_cuda((cub::DispatchSegmentedRadixSort<
|
safe_cuda((cub::DispatchSegmentedRadixSort<
|
||||||
descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
|
descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
|
||||||
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
|
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
|
||||||
|
|||||||
@ -238,13 +238,25 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
|||||||
using DispatchScan =
|
using DispatchScan =
|
||||||
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
||||||
TupleScanOp<Tuple>, cub::NullType, int64_t>;
|
TupleScanOp<Tuple>, cub::NullType, int64_t>;
|
||||||
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
|
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||||
|
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||||
|
nullptr);
|
||||||
|
#else
|
||||||
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||||
nullptr, false);
|
nullptr, false);
|
||||||
|
#endif
|
||||||
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
|
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
|
||||||
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
|
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||||
|
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||||
|
cub::NullType(), batch.Size(), nullptr);
|
||||||
|
#else
|
||||||
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||||
cub::NullType(), batch.Size(), nullptr, false);
|
cub::NullType(), batch.Size(), nullptr, false);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user