[EM] Prevent init with CUDA malloc resource. (#10606)

This commit is contained in:
Jiaming Yuan
2024-07-21 05:08:29 +08:00
committed by GitHub
parent 0846ad860c
commit cb62f9e73b
7 changed files with 105 additions and 41 deletions

View File

@@ -404,7 +404,7 @@ size_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bs
LOG(FATAL) << "Concatenating the same Ellpack.";
return this->n_rows * this->row_stride;
}
dh::LaunchN(num_elements, CopyPage{this, page, offset});
dh::LaunchN(num_elements, ctx->CUDACtx()->Stream(), CopyPage{this, page, offset});
monitor_.Stop(__func__);
return num_elements;
}

View File

@@ -6,6 +6,7 @@
#include <cstddef> // for size_t
#include <vector> // for vector
#include "../common/cuda_rt_utils.h"
#include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream
#include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc
#include "../common/ref_resource_view.h" // for ReadVec, WriteVec
@@ -21,6 +22,8 @@ namespace {
template <typename T>
[[nodiscard]] bool ReadDeviceVec(common::AlignedResourceReadStream* fi,
common::RefResourceView<T>* vec) {
xgboost_NVTX_FN_RANGE();
std::uint64_t n{0};
if (!fi->Read(&n)) {
return false;
@@ -37,7 +40,7 @@ template <typename T>
}
auto ctx = Context{}.MakeCUDA(common::CurrentDevice());
*vec = common::MakeFixedVecWithCudaMalloc(&ctx, n, static_cast<T>(0));
*vec = common::MakeFixedVecWithCudaMalloc<T>(&ctx, n);
dh::safe_cuda(cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, dh::DefaultStream()));
return true;
}
@@ -50,6 +53,7 @@ template <typename T>
[[nodiscard]] bool EllpackPageRawFormat::Read(EllpackPage* page,
common::AlignedResourceReadStream* fi) {
xgboost_NVTX_FN_RANGE();
auto* impl = page->Impl();
impl->SetCuts(this->cuts_);
@@ -69,6 +73,8 @@ template <typename T>
[[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page,
common::AlignedFileWriteStream* fo) {
xgboost_NVTX_FN_RANGE();
std::size_t bytes{0};
auto* impl = page.Impl();
bytes += fo->Write(impl->n_rows);
@@ -84,22 +90,30 @@ template <typename T>
}
[[nodiscard]] bool EllpackPageRawFormat::Read(EllpackPage* page, EllpackHostCacheStream* fi) const {
xgboost_NVTX_FN_RANGE();
auto* impl = page->Impl();
CHECK(this->cuts_->cut_values_.DeviceCanRead());
impl->SetCuts(this->cuts_);
// Read vector
Context ctx = Context{}.MakeCUDA(common::CurrentDevice());
// Reads the index buffer from `fi` into `impl->gidx_buffer`: first an element count, then,
// when the count is non-zero, that many elements' worth of bytes.
// Yields true on success. RET_IF_NOT is defined outside this view; presumably it makes the
// lambda return false when the wrapped read fails -- confirm at its definition.
auto read_vec = [&] {
// Scoped NVTX range labelled "read-vec" (presumably for profiler timelines).
common::NvtxScopedRange range{common::NvtxEventAttr{"read-vec", common::NvtxRgb{127, 255, 0}}};
// Element count stored ahead of the payload.
bst_idx_t n{0};
RET_IF_NOT(fi->Read(&n));
if (n == 0) {
// Nothing was stored: skip both the allocation and the payload read.
return true;
}
// Allocate n elements without supplying an initial value; the read below targets the
// buffer's full size_bytes().
impl->gidx_buffer = common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(&ctx, n);
RET_IF_NOT(fi->Read(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes()));
return true;
};
RET_IF_NOT(read_vec());
RET_IF_NOT(fi->Read(&impl->n_rows));
RET_IF_NOT(fi->Read(&impl->is_dense));
RET_IF_NOT(fi->Read(&impl->row_stride));
// Read vec
Context ctx = Context{}.MakeCUDA(common::CurrentDevice());
bst_idx_t n{0};
RET_IF_NOT(fi->Read(&n));
if (n != 0) {
impl->gidx_buffer =
common::MakeFixedVecWithCudaMalloc(&ctx, n, static_cast<common::CompressedByteT>(0));
RET_IF_NOT(fi->Read(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes()));
}
RET_IF_NOT(fi->Read(&impl->base_rowid));
dh::DefaultStream().Sync();
@@ -108,19 +122,27 @@ template <typename T>
[[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page,
EllpackHostCacheStream* fo) const {
xgboost_NVTX_FN_RANGE();
bst_idx_t bytes{0};
auto* impl = page.Impl();
// Write vector
// Writes the index buffer to `fo` as an element count followed by the buffer contents,
// adding the value returned by each Write call to the captured `bytes` counter.
auto write_vec = [&] {
// Scoped NVTX range labelled "write-vec" (presumably for profiler timelines).
common::NvtxScopedRange range{common::NvtxEventAttr{"write-vec", common::NvtxRgb{127, 255, 0}}};
// The count is always written, even for an empty buffer.
bst_idx_t n = impl->gidx_buffer.size();
bytes += fo->Write(n);
// The payload is written only when the buffer holds at least one element.
if (!impl->gidx_buffer.empty()) {
bytes += fo->Write(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes());
}
};
write_vec();
bytes += fo->Write(impl->n_rows);
bytes += fo->Write(impl->is_dense);
bytes += fo->Write(impl->row_stride);
// Write vector
bst_idx_t n = impl->gidx_buffer.size();
bytes += fo->Write(n);
if (!impl->gidx_buffer.empty()) {
bytes += fo->Write(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes());
}
bytes += fo->Write(impl->base_rowid);
dh::DefaultStream().Sync();