restore stream logic
This commit is contained in:
parent
b3ee7a59c7
commit
1b6c6baf76
@ -11,6 +11,33 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) {
|
void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) {
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
if (stream == 1 || stream == 2) {
|
||||||
|
// Default streams, no synchronization needed
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
hipEvent_t event;
|
||||||
|
hipError_t err = hipEventCreate(&event);
|
||||||
|
if (err != hipSuccess) {
|
||||||
|
LOG(WARNING) << "Failed to create HIP event: " << hipGetErrorString(err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
err = hipEventRecord(event, reinterpret_cast<hipStream_t>(stream));
|
||||||
|
if (err != hipSuccess) {
|
||||||
|
LOG(WARNING) << "Failed to record HIP event: " << hipGetErrorString(err);
|
||||||
|
hipEventDestroy(event);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
err = hipStreamWaitEvent(hipStreamPerThread, event, 0);
|
||||||
|
if (err != hipSuccess) {
|
||||||
|
LOG(WARNING) << "Failed to wait for HIP event: " << hipGetErrorString(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
hipEventDestroy(event);
|
||||||
|
#else
|
||||||
switch (stream) {
|
switch (stream) {
|
||||||
case 0:
|
case 0:
|
||||||
/**
|
/**
|
||||||
@ -31,6 +58,7 @@ void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) {
|
|||||||
dh::DefaultStream().Wait(e);
|
dh::DefaultStream().Wait(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user