Wait for data CUDA stream instead of sync. (#9144)

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2023-05-09 09:52:21 +08:00 committed by GitHub
parent a075aa24ba
commit 85988a3178
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 22 deletions

View File

@ -1352,14 +1352,12 @@ class CUDAStream {
cudaStream_t stream_; cudaStream_t stream_;
public: public:
CUDAStream() { CUDAStream() { dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); }
dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); ~CUDAStream() { dh::safe_cuda(cudaStreamDestroy(stream_)); }
}
~CUDAStream() { [[nodiscard]] CUDAStreamView View() const { return CUDAStreamView{stream_}; }
dh::safe_cuda(cudaStreamDestroy(stream_)); [[nodiscard]] cudaStream_t Handle() const { return stream_; }
}
CUDAStreamView View() const { return CUDAStreamView{stream_}; }
void Sync() { this->View().Sync(); } void Sync() { this->View().Sync(); }
}; };

View File

@ -1,11 +1,15 @@
/*! /**
* Copyright 2021 by Contributors * Copyright 2021-2023, XGBoost Contributors
*/ */
#include <cstdint> // for int64_t
#include "../common/common.h" #include "../common/common.h"
#include "../common/device_helpers.cuh" // for DefaultStream, CUDAEvent
#include "array_interface.h" #include "array_interface.h"
#include "xgboost/logging.h"
namespace xgboost { namespace xgboost {
void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) {
switch (stream) { switch (stream) {
case 0: case 0:
/** /**
@ -22,8 +26,11 @@ void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
break; break;
case 2: case 2:
// default per-thread stream // default per-thread stream
default: default: {
dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream))); dh::CUDAEvent e;
e.Record(dh::CUDAStreamView{reinterpret_cast<cudaStream_t>(stream)});
dh::DefaultStream().Wait(e);
}
} }
} }

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2021 by Contributors * Copyright 2021-2023, XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
@ -22,22 +22,19 @@ TEST(ArrayInterface, Stream) {
HostDeviceVector<float> storage; HostDeviceVector<float> storage;
auto arr_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); auto arr_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
cudaStream_t stream; dh::CUDAStream stream;
cudaStreamCreate(&stream);
auto j_arr =Json::Load(StringView{arr_str}); auto j_arr = Json::Load(StringView{arr_str});
j_arr["stream"] = Integer(reinterpret_cast<int64_t>(stream)); j_arr["stream"] = Integer(reinterpret_cast<int64_t>(stream.Handle()));
Json::Dump(j_arr, &arr_str); Json::Dump(j_arr, &arr_str);
dh::caching_device_vector<uint64_t> out(1, 0); dh::caching_device_vector<uint64_t> out(1, 0);
uint64_t dur = 1e9; std::uint64_t dur = 1e9;
dh::LaunchKernel{1, 1, 0, stream}(SleepForTest, out.data().get(), dur); dh::LaunchKernel{1, 1, 0, stream.View()}(SleepForTest, out.data().get(), dur);
ArrayInterface<2> arr(arr_str); ArrayInterface<2> arr(arr_str);
auto t = out[0]; auto t = out[0];
CHECK_GE(t, dur); CHECK_GE(t, dur);
cudaStreamDestroy(stream);
} }
TEST(ArrayInterface, Ptr) { TEST(ArrayInterface, Ptr) {