Fix deprecated CUB calls in CUDA 12.0 (#8578)

This commit is contained in:
Rong Ou
2022-12-12 01:02:30 -08:00
committed by GitHub
parent 35d8447282
commit 15a88ceef0
2 changed files with 59 additions and 1 deletions

View File

@@ -238,13 +238,25 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
using DispatchScan =
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
TupleScanOp<Tuple>, cub::NullType, int64_t>;
#if THRUST_MAJOR_VERSION >= 2
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
nullptr);
#else
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
nullptr, false);
#endif
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
#if THRUST_MAJOR_VERSION >= 2
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(),
cub::NullType(), batch.Size(), nullptr);
#else
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(),
cub::NullType(), batch.Size(), nullptr, false);
#endif
}
void WriteNullValues(EllpackPageImpl* dst, int device_idx,