Wait for data CUDA stream instead of sync. (#9144)
--------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
a075aa24ba
commit
85988a3178
@ -1352,14 +1352,12 @@ class CUDAStream {
|
|||||||
cudaStream_t stream_;
|
cudaStream_t stream_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
CUDAStream() {
|
CUDAStream() { dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); }
|
||||||
dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
|
~CUDAStream() { dh::safe_cuda(cudaStreamDestroy(stream_)); }
|
||||||
}
|
|
||||||
~CUDAStream() {
|
[[nodiscard]] CUDAStreamView View() const { return CUDAStreamView{stream_}; }
|
||||||
dh::safe_cuda(cudaStreamDestroy(stream_));
|
[[nodiscard]] cudaStream_t Handle() const { return stream_; }
|
||||||
}
|
|
||||||
|
|
||||||
CUDAStreamView View() const { return CUDAStreamView{stream_}; }
|
|
||||||
void Sync() { this->View().Sync(); }
|
void Sync() { this->View().Sync(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,15 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021 by Contributors
|
* Copyright 2021-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
#include <cstdint> // for int64_t
|
||||||
|
|
||||||
#include "../common/common.h"
|
#include "../common/common.h"
|
||||||
|
#include "../common/device_helpers.cuh" // for DefaultStream, CUDAEvent
|
||||||
#include "array_interface.h"
|
#include "array_interface.h"
|
||||||
|
#include "xgboost/logging.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
|
void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) {
|
||||||
switch (stream) {
|
switch (stream) {
|
||||||
case 0:
|
case 0:
|
||||||
/**
|
/**
|
||||||
@ -22,8 +26,11 @@ void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
|
|||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
// default per-thread stream
|
// default per-thread stream
|
||||||
default:
|
default: {
|
||||||
dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)));
|
dh::CUDAEvent e;
|
||||||
|
e.Record(dh::CUDAStreamView{reinterpret_cast<cudaStream_t>(stream)});
|
||||||
|
dh::DefaultStream().Wait(e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021 by Contributors
|
* Copyright 2021-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
@ -22,22 +22,19 @@ TEST(ArrayInterface, Stream) {
|
|||||||
HostDeviceVector<float> storage;
|
HostDeviceVector<float> storage;
|
||||||
auto arr_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
|
auto arr_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
|
||||||
|
|
||||||
cudaStream_t stream;
|
dh::CUDAStream stream;
|
||||||
cudaStreamCreate(&stream);
|
|
||||||
|
|
||||||
auto j_arr =Json::Load(StringView{arr_str});
|
auto j_arr = Json::Load(StringView{arr_str});
|
||||||
j_arr["stream"] = Integer(reinterpret_cast<int64_t>(stream));
|
j_arr["stream"] = Integer(reinterpret_cast<int64_t>(stream.Handle()));
|
||||||
Json::Dump(j_arr, &arr_str);
|
Json::Dump(j_arr, &arr_str);
|
||||||
|
|
||||||
dh::caching_device_vector<uint64_t> out(1, 0);
|
dh::caching_device_vector<uint64_t> out(1, 0);
|
||||||
uint64_t dur = 1e9;
|
std::uint64_t dur = 1e9;
|
||||||
dh::LaunchKernel{1, 1, 0, stream}(SleepForTest, out.data().get(), dur);
|
dh::LaunchKernel{1, 1, 0, stream.View()}(SleepForTest, out.data().get(), dur);
|
||||||
ArrayInterface<2> arr(arr_str);
|
ArrayInterface<2> arr(arr_str);
|
||||||
|
|
||||||
auto t = out[0];
|
auto t = out[0];
|
||||||
CHECK_GE(t, dur);
|
CHECK_GE(t, dur);
|
||||||
|
|
||||||
cudaStreamDestroy(stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ArrayInterface, Ptr) {
|
TEST(ArrayInterface, Ptr) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user