SumReduction logging
This commit is contained in:
parent
bf2ef6c586
commit
db66fad9e9
@ -665,21 +665,31 @@ struct PinnedMemory {
|
|||||||
*/
|
*/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
typename std::iterator_traits<T>::value_type SumReduction(T in, int nVals) {
|
typename std::iterator_traits<T>::value_type SumReduction(T in, int nVals) {
|
||||||
|
std::cerr << "Entering SumReduction, nVals: " << nVals << std::endl;
|
||||||
using ValueT = typename std::iterator_traits<T>::value_type;
|
using ValueT = typename std::iterator_traits<T>::value_type;
|
||||||
|
|
||||||
size_t tmpSize {0};
|
size_t tmpSize {0};
|
||||||
ValueT *dummy_out = nullptr;
|
ValueT *dummy_out = nullptr;
|
||||||
|
|
||||||
|
try {
|
||||||
dh::safe_cuda(hipcub::DeviceReduce::Sum(nullptr, tmpSize, in, dummy_out, nVals));
|
dh::safe_cuda(hipcub::DeviceReduce::Sum(nullptr, tmpSize, in, dummy_out, nVals));
|
||||||
|
std::cerr << "Temporary storage size: " << tmpSize << std::endl;
|
||||||
|
|
||||||
TemporaryArray<char> temp(tmpSize + sizeof(ValueT));
|
TemporaryArray<char> temp(tmpSize + sizeof(ValueT));
|
||||||
auto ptr = reinterpret_cast<ValueT *>(temp.data().get()) + 1;
|
auto ptr = reinterpret_cast<ValueT *>(temp.data().get()) + 1;
|
||||||
|
|
||||||
dh::safe_cuda(hipcub::DeviceReduce::Sum(
|
dh::safe_cuda(hipcub::DeviceReduce::Sum(
|
||||||
reinterpret_cast<void *>(ptr), tmpSize, in,
|
reinterpret_cast<void *>(ptr), tmpSize, in, reinterpret_cast<ValueT *>(temp.data().get()), nVals));
|
||||||
reinterpret_cast<ValueT *>(temp.data().get()),
|
|
||||||
nVals));
|
|
||||||
ValueT sum;
|
ValueT sum;
|
||||||
dh::safe_cuda(hipMemcpy(&sum, temp.data().get(), sizeof(ValueT),
|
dh::safe_cuda(hipMemcpy(&sum, temp.data().get(), sizeof(ValueT), hipMemcpyDeviceToHost));
|
||||||
hipMemcpyDeviceToHost));
|
|
||||||
|
std::cerr << "SumReduction completed successfully" << std::endl;
|
||||||
return sum;
|
return sum;
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
std::cerr << "Exception in SumReduction: " << e.what() << std::endl;
|
||||||
|
throw;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr std::pair<int, int> CUDAVersion() {
|
constexpr std::pair<int, int> CUDAVersion() {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user