From 1b6c6baf760f05a839611d4cd8665552bee63404 Mon Sep 17 00:00:00 2001 From: Hendrik Groove Date: Mon, 21 Oct 2024 01:39:43 +0200 Subject: [PATCH] restore stream logic --- src/data/array_interface.cu | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/data/array_interface.cu b/src/data/array_interface.cu index 3aabdd9ba..11baf4c1f 100644 --- a/src/data/array_interface.cu +++ b/src/data/array_interface.cu @@ -11,6 +11,33 @@ namespace xgboost { void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) { +#if defined(XGBOOST_USE_HIP) + if (stream == 1 || stream == 2) { + // Default streams, no synchronization needed + return; + } + + hipEvent_t event; + hipError_t err = hipEventCreate(&event); + if (err != hipSuccess) { + LOG(WARNING) << "Failed to create HIP event: " << hipGetErrorString(err); + return; + } + + err = hipEventRecord(event, reinterpret_cast(stream)); + if (err != hipSuccess) { + LOG(WARNING) << "Failed to record HIP event: " << hipGetErrorString(err); + hipEventDestroy(event); + return; + } + + err = hipStreamWaitEvent(hipStreamPerThread, event, 0); + if (err != hipSuccess) { + LOG(WARNING) << "Failed to wait for HIP event: " << hipGetErrorString(err); + } + + hipEventDestroy(event); +#else switch (stream) { case 0: /** @@ -31,6 +58,7 @@ void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) { dh::DefaultStream().Wait(e); } } +#endif } bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {