From ca6fcd361ee7e957c33b45446e2708b666e078f7 Mon Sep 17 00:00:00 2001 From: Hendrik Groove Date: Mon, 21 Oct 2024 00:13:27 +0200 Subject: [PATCH] fix --- src/common/device_helpers.hip.h | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/common/device_helpers.hip.h b/src/common/device_helpers.hip.h index 5fda1ce47..94db633fe 100644 --- a/src/common/device_helpers.hip.h +++ b/src/common/device_helpers.hip.h @@ -987,16 +987,24 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce using Ty = std::remove_cv_t; Ty aggregate = init; - // Get the HIP stream from the policy - hipStream_t stream = thrust::hip::stream(policy); - std::cerr << "HIP stream: " << stream << std::endl; + // Try to get the HIP stream from the policy + hipStream_t stream = nullptr; + try { + stream = policy.stream(); + std::cerr << "HIP stream from policy: " << stream << std::endl; + } catch (const std::exception& e) { + std::cerr << "Unable to get stream from policy: " << e.what() << std::endl; + std::cerr << "Using default stream" << std::endl; + } - // Check stream validity - hipError_t stream_err = hipStreamQuery(stream); - if (stream_err != hipSuccess && stream_err != hipErrorNotReady) { - std::cerr << "Invalid stream: " << hipGetErrorString(stream_err) << std::endl; - } else { - std::cerr << "Stream is valid" << std::endl; + // Check stream validity if we got a stream + if (stream != nullptr) { + hipError_t stream_err = hipStreamQuery(stream); + if (stream_err != hipSuccess && stream_err != hipErrorNotReady) { + std::cerr << "Invalid stream: " << hipGetErrorString(stream_err) << std::endl; + } else { + std::cerr << "Stream is valid" << std::endl; + } } for (size_t offset = 0; offset < size; offset += kLimit) {