Wait for data CUDA stream instead of sync. (#9144)

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2023-05-09 09:52:21 +08:00
committed by GitHub
parent a075aa24ba
commit 85988a3178
3 changed files with 24 additions and 22 deletions

View File

@@ -1352,14 +1352,12 @@ class CUDAStream {
cudaStream_t stream_;
public:
CUDAStream() {
dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
}
~CUDAStream() {
dh::safe_cuda(cudaStreamDestroy(stream_));
}
CUDAStream() { dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); }
~CUDAStream() { dh::safe_cuda(cudaStreamDestroy(stream_)); }
[[nodiscard]] CUDAStreamView View() const { return CUDAStreamView{stream_}; }
[[nodiscard]] cudaStream_t Handle() const { return stream_; }
CUDAStreamView View() const { return CUDAStreamView{stream_}; }
void Sync() { this->View().Sync(); }
};

View File

@@ -1,11 +1,15 @@
/*!
* Copyright 2021 by Contributors
/**
* Copyright 2021-2023, XGBoost Contributors
*/
#include <cstdint> // for int64_t
#include "../common/common.h"
#include "../common/device_helpers.cuh" // for DefaultStream, CUDAEvent
#include "array_interface.h"
#include "xgboost/logging.h"
namespace xgboost {
void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) {
switch (stream) {
case 0:
/**
@@ -22,8 +26,11 @@ void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
break;
case 2:
// default per-thread stream
default:
dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)));
default: {
dh::CUDAEvent e;
e.Record(dh::CUDAStreamView{reinterpret_cast<cudaStream_t>(stream)});
dh::DefaultStream().Wait(e);
}
}
}