[EM] Prevent init with CUDA malloc resource. (#10606)
This commit is contained in:
@@ -404,7 +404,7 @@ size_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bs
|
||||
LOG(FATAL) << "Concatenating the same Ellpack.";
|
||||
return this->n_rows * this->row_stride;
|
||||
}
|
||||
dh::LaunchN(num_elements, CopyPage{this, page, offset});
|
||||
dh::LaunchN(num_elements, ctx->CUDACtx()->Stream(), CopyPage{this, page, offset});
|
||||
monitor_.Stop(__func__);
|
||||
return num_elements;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/cuda_rt_utils.h"
|
||||
#include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream
|
||||
#include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc
|
||||
#include "../common/ref_resource_view.h" // for ReadVec, WriteVec
|
||||
@@ -21,6 +22,8 @@ namespace {
|
||||
template <typename T>
|
||||
[[nodiscard]] bool ReadDeviceVec(common::AlignedResourceReadStream* fi,
|
||||
common::RefResourceView<T>* vec) {
|
||||
xgboost_NVTX_FN_RANGE();
|
||||
|
||||
std::uint64_t n{0};
|
||||
if (!fi->Read(&n)) {
|
||||
return false;
|
||||
@@ -37,7 +40,7 @@ template <typename T>
|
||||
}
|
||||
|
||||
auto ctx = Context{}.MakeCUDA(common::CurrentDevice());
|
||||
*vec = common::MakeFixedVecWithCudaMalloc(&ctx, n, static_cast<T>(0));
|
||||
*vec = common::MakeFixedVecWithCudaMalloc<T>(&ctx, n);
|
||||
dh::safe_cuda(cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, dh::DefaultStream()));
|
||||
return true;
|
||||
}
|
||||
@@ -50,6 +53,7 @@ template <typename T>
|
||||
|
||||
[[nodiscard]] bool EllpackPageRawFormat::Read(EllpackPage* page,
|
||||
common::AlignedResourceReadStream* fi) {
|
||||
xgboost_NVTX_FN_RANGE();
|
||||
auto* impl = page->Impl();
|
||||
|
||||
impl->SetCuts(this->cuts_);
|
||||
@@ -69,6 +73,8 @@ template <typename T>
|
||||
|
||||
[[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page,
|
||||
common::AlignedFileWriteStream* fo) {
|
||||
xgboost_NVTX_FN_RANGE();
|
||||
|
||||
std::size_t bytes{0};
|
||||
auto* impl = page.Impl();
|
||||
bytes += fo->Write(impl->n_rows);
|
||||
@@ -84,22 +90,30 @@ template <typename T>
|
||||
}
|
||||
|
||||
[[nodiscard]] bool EllpackPageRawFormat::Read(EllpackPage* page, EllpackHostCacheStream* fi) const {
|
||||
xgboost_NVTX_FN_RANGE();
|
||||
|
||||
auto* impl = page->Impl();
|
||||
CHECK(this->cuts_->cut_values_.DeviceCanRead());
|
||||
impl->SetCuts(this->cuts_);
|
||||
|
||||
// Read vector
|
||||
Context ctx = Context{}.MakeCUDA(common::CurrentDevice());
|
||||
auto read_vec = [&] {
|
||||
common::NvtxScopedRange range{common::NvtxEventAttr{"read-vec", common::NvtxRgb{127, 255, 0}}};
|
||||
bst_idx_t n{0};
|
||||
RET_IF_NOT(fi->Read(&n));
|
||||
if (n == 0) {
|
||||
return true;
|
||||
}
|
||||
impl->gidx_buffer = common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(&ctx, n);
|
||||
RET_IF_NOT(fi->Read(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes()));
|
||||
return true;
|
||||
};
|
||||
RET_IF_NOT(read_vec());
|
||||
|
||||
RET_IF_NOT(fi->Read(&impl->n_rows));
|
||||
RET_IF_NOT(fi->Read(&impl->is_dense));
|
||||
RET_IF_NOT(fi->Read(&impl->row_stride));
|
||||
|
||||
// Read vec
|
||||
Context ctx = Context{}.MakeCUDA(common::CurrentDevice());
|
||||
bst_idx_t n{0};
|
||||
RET_IF_NOT(fi->Read(&n));
|
||||
if (n != 0) {
|
||||
impl->gidx_buffer =
|
||||
common::MakeFixedVecWithCudaMalloc(&ctx, n, static_cast<common::CompressedByteT>(0));
|
||||
RET_IF_NOT(fi->Read(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes()));
|
||||
}
|
||||
RET_IF_NOT(fi->Read(&impl->base_rowid));
|
||||
|
||||
dh::DefaultStream().Sync();
|
||||
@@ -108,19 +122,27 @@ template <typename T>
|
||||
|
||||
[[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page,
|
||||
EllpackHostCacheStream* fo) const {
|
||||
xgboost_NVTX_FN_RANGE();
|
||||
|
||||
bst_idx_t bytes{0};
|
||||
auto* impl = page.Impl();
|
||||
|
||||
// Write vector
|
||||
auto write_vec = [&] {
|
||||
common::NvtxScopedRange range{common::NvtxEventAttr{"write-vec", common::NvtxRgb{127, 255, 0}}};
|
||||
bst_idx_t n = impl->gidx_buffer.size();
|
||||
bytes += fo->Write(n);
|
||||
|
||||
if (!impl->gidx_buffer.empty()) {
|
||||
bytes += fo->Write(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes());
|
||||
}
|
||||
};
|
||||
|
||||
write_vec();
|
||||
|
||||
bytes += fo->Write(impl->n_rows);
|
||||
bytes += fo->Write(impl->is_dense);
|
||||
bytes += fo->Write(impl->row_stride);
|
||||
|
||||
// Write vector
|
||||
bst_idx_t n = impl->gidx_buffer.size();
|
||||
bytes += fo->Write(n);
|
||||
|
||||
if (!impl->gidx_buffer.empty()) {
|
||||
bytes += fo->Write(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes());
|
||||
}
|
||||
bytes += fo->Write(impl->base_rowid);
|
||||
|
||||
dh::DefaultStream().Sync();
|
||||
|
||||
Reference in New Issue
Block a user