diff --git a/src/common/device_helpers.hip.h b/src/common/device_helpers.hip.h index 94db633fe..bfeffe7c5 100644 --- a/src/common/device_helpers.hip.h +++ b/src/common/device_helpers.hip.h @@ -987,17 +987,17 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce using Ty = std::remove_cv_t; Ty aggregate = init; - // Try to get the HIP stream from the policy + // Get the stream from the policy hipStream_t stream = nullptr; try { stream = policy.stream(); - std::cerr << "HIP stream from policy: " << stream << std::endl; + std::cerr << "HIP stream from policy: " << static_cast(stream) << std::endl; } catch (const std::exception& e) { std::cerr << "Unable to get stream from policy: " << e.what() << std::endl; std::cerr << "Using default stream" << std::endl; } - // Check stream validity if we got a stream + // Check stream validity if (stream != nullptr) { hipError_t stream_err = hipStreamQuery(stream); if (stream_err != hipSuccess && stream_err != hipErrorNotReady) { @@ -1019,7 +1019,8 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce // Get the raw pointers for debugging auto raw_begin = thrust::raw_pointer_cast(&*begin_it); auto raw_end = thrust::raw_pointer_cast(&*end_it); - std::cerr << "Raw pointers - begin: " << raw_begin << ", end: " << raw_end << std::endl; + std::cerr << "Raw pointers - begin: " << static_cast(raw_begin) + << ", end: " << static_cast(raw_end) << std::endl; // Check if the pointers are valid device pointers hipPointerAttribute_t attrs; @@ -1034,15 +1035,20 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce aggregate = reduce_op(aggregate, ret); std::cerr << "Batch reduction completed successfully" << std::endl; + } catch (const thrust::system_error& e) { + std::cerr << "Thrust system error in reduce: " << e.what() << std::endl; + std::cerr << "Error code: " << e.code() << std::endl; + throw; } catch (const std::exception& e) { std::cerr << "Exception in thrust::reduce: " << e.what() << std::endl; - - // Get the last HIP error - hipError_t last_error = hipGetLastError(); - std::cerr << "Last HIP error: " << hipGetErrorString(last_error) << std::endl; - throw; } + + // Check for any HIP errors after the reduction + hipError_t hip_err = hipGetLastError(); + if (hip_err != hipSuccess) { + std::cerr << "HIP error after reduction: " << hipGetErrorString(hip_err) << std::endl; + } } std::cerr << "Exiting Reduce function" << std::endl;