Wait for the data's CUDA stream instead of synchronizing it. (#9144)

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1352,14 +1352,12 @@ class CUDAStream {
|
||||
cudaStream_t stream_;
|
||||
|
||||
public:
|
||||
CUDAStream() {
|
||||
dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
|
||||
}
|
||||
~CUDAStream() {
|
||||
dh::safe_cuda(cudaStreamDestroy(stream_));
|
||||
}
|
||||
CUDAStream() { dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); }
|
||||
~CUDAStream() { dh::safe_cuda(cudaStreamDestroy(stream_)); }
|
||||
|
||||
[[nodiscard]] CUDAStreamView View() const { return CUDAStreamView{stream_}; }
|
||||
[[nodiscard]] cudaStream_t Handle() const { return stream_; }
|
||||
|
||||
CUDAStreamView View() const { return CUDAStreamView{stream_}; }
|
||||
void Sync() { this->View().Sync(); }
|
||||
};
|
||||
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
/*!
|
||||
* Copyright 2021 by Contributors
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
*/
|
||||
#include <cstdint> // for int64_t
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/device_helpers.cuh" // for DefaultStream, CUDAEvent
|
||||
#include "array_interface.h"
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost {
|
||||
void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
|
||||
void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) {
|
||||
switch (stream) {
|
||||
case 0:
|
||||
/**
|
||||
@@ -22,8 +26,11 @@ void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
|
||||
break;
|
||||
case 2:
|
||||
// default per-thread stream
|
||||
default:
|
||||
dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)));
|
||||
default: {
|
||||
dh::CUDAEvent e;
|
||||
e.Record(dh::CUDAStreamView{reinterpret_cast<cudaStream_t>(stream)});
|
||||
dh::DefaultStream().Wait(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user