Fix deprecated CUB calls in CUDA 12.0 (#8578)
This commit is contained in:
@@ -238,13 +238,25 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
using DispatchScan =
|
||||
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
||||
TupleScanOp<Tuple>, cub::NullType, int64_t>;
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
nullptr);
|
||||
#else
|
||||
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
nullptr, false);
|
||||
#endif
|
||||
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
cub::NullType(), batch.Size(), nullptr);
|
||||
#else
|
||||
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
cub::NullType(), batch.Size(), nullptr, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||
|
||||
Reference in New Issue
Block a user