diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu b/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu
index bd428189f..9621a6147 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu
@@ -132,7 +132,7 @@ class DataIteratorProxy {
   bool cache_on_host_{true};  // TODO(Bobby): Make this optional.
 
   template <typename T>
-  using Alloc = xgboost::common::cuda_impl::pinned_allocator<T>;
+  using Alloc = xgboost::common::cuda_impl::PinnedAllocator<T>;
   template <typename T>
   using HostVector = std::vector<T, Alloc<T>>;
 
diff --git a/src/common/cuda_pinned_allocator.h b/src/common/cuda_pinned_allocator.h
index c53ae4517..90c34668a 100644
--- a/src/common/cuda_pinned_allocator.h
+++ b/src/common/cuda_pinned_allocator.h
@@ -21,7 +21,6 @@ namespace xgboost::common::cuda_impl {
 // that Thrust used to provide.
 //
 // \see https://en.cppreference.com/w/cpp/memory/allocator
-
 template <typename T>
 struct PinnedAllocPolicy {
   using pointer = T*;              // NOLINT: The type returned by address() / allocate()
@@ -33,7 +32,7 @@ struct PinnedAllocPolicy {
     return std::numeric_limits<size_type>::max() / sizeof(value_type);
   }
 
-  pointer allocate(size_type cnt, const_pointer = nullptr) {  // NOLINT
+  [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const {  // NOLINT
     if (cnt > this->max_size()) {
       throw std::bad_alloc{};
     }  // end if
@@ -57,7 +56,7 @@ struct ManagedAllocPolicy {
     return std::numeric_limits<size_type>::max() / sizeof(value_type);
   }
 
-  pointer allocate(size_type cnt, const_pointer = nullptr) {  // NOLINT
+  [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const {  // NOLINT
     if (cnt > this->max_size()) {
       throw std::bad_alloc{};
     }  // end if
@@ -70,16 +69,49 @@ struct ManagedAllocPolicy {
   void deallocate(pointer p, size_type) { dh::safe_cuda(cudaFree(p)); }  // NOLINT
 };
 
-template <typename T, template <typename> typename Policy>
-class CudaHostAllocatorImpl : public Policy<T> {  // NOLINT
- public:
-  using value_type = typename Policy<T>::value_type;        // NOLINT
-  using pointer = typename Policy<T>::pointer;              // NOLINT
-  using const_pointer = typename Policy<T>::const_pointer;  // NOLINT
-  using size_type = typename Policy<T>::size_type;          // NOLINT
+// This is actually a pinned memory allocator in disguise. We utilize HMM or ATS for
+// efficient tracked memory allocation.
+template <typename T>
+struct SamAllocPolicy {
+  using pointer = T*;              // NOLINT: The type returned by address() / allocate()
+  using const_pointer = const T*;  // NOLINT: The type returned by address()
+  using size_type = std::size_t;   // NOLINT: The type used for the size of the allocation
+  using value_type = T;            // NOLINT: The type of the elements in the allocator
 
-  using reference = T&;              // NOLINT: The parameter type for address()
-  using const_reference = const T&;  // NOLINT: The parameter type for address()
+  size_type max_size() const {  // NOLINT
+    return std::numeric_limits<size_type>::max() / sizeof(value_type);
+  }
+
+  [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const {  // NOLINT
+    if (cnt > this->max_size()) {
+      throw std::bad_alloc{};
+    }  // end if
+
+    size_type n_bytes = cnt * sizeof(value_type);
+    pointer result = reinterpret_cast<pointer>(std::malloc(n_bytes));
+    if (!result) {
+      throw std::bad_alloc{};
+    }
+    dh::safe_cuda(cudaHostRegister(result, n_bytes, cudaHostRegisterDefault));
+    return result;
+  }
+
+  void deallocate(pointer p, size_type) {  // NOLINT
+    dh::safe_cuda(cudaHostUnregister(p));
+    std::free(p);
+  }
+};
+
+template <typename T, template <typename> typename Policy>
+class CudaHostAllocatorImpl : public Policy<T> {
+ public:
+  using typename Policy<T>::value_type;
+  using typename Policy<T>::pointer;
+  using typename Policy<T>::const_pointer;
+  using typename Policy<T>::size_type;
+
+  using reference = value_type&;              // NOLINT: The parameter type for address()
+  using const_reference = const value_type&;  // NOLINT: The parameter type for address()
 
   using difference_type = std::ptrdiff_t;  // NOLINT: The type of the distance between two pointers
@@ -101,14 +133,17 @@ class CudaHostAllocatorImpl : public Policy<T> {  // NOLINT
   pointer address(reference r) { return &r; }              // NOLINT
   const_pointer address(const_reference r) { return &r; }  // NOLINT
 
-  bool operator==(CudaHostAllocatorImpl const& x) const { return true; }
+  bool operator==(CudaHostAllocatorImpl const&) const { return true; }
 
   bool operator!=(CudaHostAllocatorImpl const& x) const { return !operator==(x); }
 };
 
 template <typename T>
-using pinned_allocator = CudaHostAllocatorImpl<T, PinnedAllocPolicy>;  // NOLINT
+using PinnedAllocator = CudaHostAllocatorImpl<T, PinnedAllocPolicy>;  // NOLINT
 
 template <typename T>
-using managed_allocator = CudaHostAllocatorImpl<T, ManagedAllocPolicy>;  // NOLINT
+using ManagedAllocator = CudaHostAllocatorImpl<T, ManagedAllocPolicy>;  // NOLINT
+
+template <typename T>
+using SamAllocator = CudaHostAllocatorImpl<T, SamAllocPolicy>;
 }  // namespace xgboost::common::cuda_impl
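Note: the two pre-existing policies obtain page-locked ("pinned") host memory directly from the driver, while the new SamAllocPolicy arrives at pinned memory from the other direction: it takes ordinary malloc'd system memory and page-locks it with cudaHostRegister, which suits the HMM/ATS systems mentioned in the comment, where host and device share an address space. A standalone sketch of the two lifecycles (illustration only, error handling omitted; not part of the patch):

    #include <cstdlib>
    #include <cuda_runtime.h>

    int main() {
      std::size_t n_bytes = 1 << 20;

      // What PinnedAllocPolicy wraps: driver-allocated pinned memory.
      void* a = nullptr;
      cudaMallocHost(&a, n_bytes);  // page-locked from the start
      cudaFreeHost(a);

      // What SamAllocPolicy wraps: pin ordinary system memory after the fact.
      void* b = std::malloc(n_bytes);
      cudaHostRegister(b, n_bytes, cudaHostRegisterDefault);
      cudaHostUnregister(b);  // must unregister before freeing
      std::free(b);
      return 0;
    }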
diff --git a/src/common/io.h b/src/common/io.h
index 5f2e28336..c82015c31 100644
--- a/src/common/io.h
+++ b/src/common/io.h
@@ -286,6 +286,7 @@ class ResourceHandler {
     kMmap = 1,
     kCudaMalloc = 2,
     kCudaMmap = 3,
+    kCudaHostCache = 4,
   };
 
  private:
@@ -310,6 +311,8 @@ class ResourceHandler {
       return "CudaMalloc";
     case kCudaMmap:
       return "CudaMmap";
+    case kCudaHostCache:
+      return "CudaHostCache";
     }
     LOG(FATAL) << "Unreachable.";
     return {};
diff --git a/src/common/ref_resource_view.cuh b/src/common/ref_resource_view.cuh
index d48b221a3..985938e08 100644
--- a/src/common/ref_resource_view.cuh
+++ b/src/common/ref_resource_view.cuh
@@ -16,8 +16,7 @@ namespace xgboost::common {
  * @brief Make a fixed size `RefResourceView` with cudaMalloc resource.
  */
 template <typename T>
-[[nodiscard]] RefResourceView<T> MakeFixedVecWithCudaMalloc(Context const*,
-                                                            std::size_t n_elements) {
+[[nodiscard]] RefResourceView<T> MakeFixedVecWithCudaMalloc(std::size_t n_elements) {
   auto resource = std::make_shared<common::CudaMallocResource>(n_elements * sizeof(T));
   auto ref = RefResourceView{resource->DataAs<T>(), n_elements, resource};
   return ref;
@@ -26,8 +25,15 @@ template <typename T>
 template <typename T>
 [[nodiscard]] RefResourceView<T> MakeFixedVecWithCudaMalloc(Context const* ctx,
                                                             std::size_t n_elements, T const& init) {
-  auto ref = MakeFixedVecWithCudaMalloc<T>(ctx, n_elements);
+  auto ref = MakeFixedVecWithCudaMalloc<T>(n_elements);
   thrust::fill_n(ctx->CUDACtx()->CTP(), ref.data(), ref.size(), init);
   return ref;
 }
+
+template <typename T>
+[[nodiscard]] RefResourceView<T> MakeFixedVecWithPinnedMalloc(std::size_t n_elements) {
+  auto resource = std::make_shared<common::CudaPinnedResource>(n_elements * sizeof(T));
+  auto ref = RefResourceView{resource->DataAs<T>(), n_elements, resource};
+  return ref;
+}
 }  // namespace xgboost::common
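The new MakeFixedVecWithPinnedMalloc mirrors the cudaMalloc factory one hunk above: allocate a resource, then hand out a typed view that co-owns it, so the storage stays alive for as long as any view references it. A self-contained toy version of that ownership scheme (all names here are hypothetical; the real classes are RefResourceView and the resources in resource.cuh):

    #include <cstddef>
    #include <cstdint>
    #include <memory>

    // Toy stand-in for a ResourceHandler: owns a flat byte buffer.
    struct ToyResource {
      explicit ToyResource(std::size_t n_bytes) : data{new std::int8_t[n_bytes]} {}
      std::unique_ptr<std::int8_t[]> data;
    };

    // Non-owning typed view; the shared_ptr keeps the storage valid for as
    // long as any view still references it.
    template <typename T>
    struct ToyRefView {
      T* ptr{nullptr};
      std::size_t size{0};
      std::shared_ptr<ToyResource> resource;
    };

    template <typename T>
    ToyRefView<T> MakeToyFixedVec(std::size_t n_elements) {
      auto res = std::make_shared<ToyResource>(n_elements * sizeof(T));
      return {reinterpret_cast<T*>(res->data.get()), n_elements, res};
    }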
diff --git a/src/common/resource.cuh b/src/common/resource.cuh
index e950a8d90..4930f8368 100644
--- a/src/common/resource.cuh
+++ b/src/common/resource.cuh
@@ -5,9 +5,10 @@
 #include <cstddef>     // for size_t
 #include <functional>  // for function
 
-#include "device_vector.cuh"      // for DeviceUVector
-#include "io.h"                   // for ResourceHandler, MMAPFile
-#include "xgboost/string_view.h"  // for StringView
+#include "cuda_pinned_allocator.h"  // for SamAllocator
+#include "device_vector.cuh"        // for DeviceUVector
+#include "io.h"                     // for ResourceHandler, MMAPFile
+#include "xgboost/string_view.h"    // for StringView
 
 namespace xgboost::common {
 /**
@@ -29,6 +30,22 @@ class CudaMallocResource : public ResourceHandler {
   void Resize(std::size_t n_bytes) { this->storage_.resize(n_bytes); }
 };
 
+class CudaPinnedResource : public ResourceHandler {
+  std::vector<std::int8_t, cuda_impl::SamAllocator<std::int8_t>> storage_;
+
+  void Clear() noexcept(true) { this->Resize(0); }
+
+ public:
+  explicit CudaPinnedResource(std::size_t n_bytes) : ResourceHandler{kCudaHostCache} {
+    this->Resize(n_bytes);
+  }
+  ~CudaPinnedResource() noexcept(true) override { this->Clear(); }
+
+  [[nodiscard]] void* Data() override { return storage_.data(); }
+  [[nodiscard]] std::size_t Size() const override { return storage_.size(); }
+  void Resize(std::size_t n_bytes) { this->storage_.resize(n_bytes); }
+};
+
 class CudaMmapResource : public ResourceHandler {
   std::unique_ptr<MMAPFile, std::function<void(MMAPFile*)>> handle_;
   std::size_t n_;
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index bb279b3d8..515625c24 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -404,7 +404,7 @@ size_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bs
   bst_idx_t num_elements = page->n_rows * page->row_stride;
   CHECK_EQ(this->row_stride, page->row_stride);
   CHECK_EQ(NumSymbols(), page->NumSymbols());
-  CHECK_GE(n_rows * row_stride, offset + num_elements);
+  CHECK_GE(this->n_rows * this->row_stride, offset + num_elements);
   if (page == this) {
     LOG(FATAL) << "Concatenating the same Ellpack.";
     return this->n_rows * this->row_stride;
@@ -542,7 +542,10 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
 // Return the number of rows contained in this page.
 [[nodiscard]] bst_idx_t EllpackPageImpl::Size() const { return n_rows; }
 
-std::size_t EllpackPageImpl::MemCostBytes() const { return this->gidx_buffer.size_bytes(); }
+std::size_t EllpackPageImpl::MemCostBytes() const {
+  return this->gidx_buffer.size_bytes() + sizeof(this->n_rows) + sizeof(this->is_dense) +
+         sizeof(this->row_stride) + sizeof(this->base_rowid);
+}
 
 EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
     DeviceOrd device, common::Span<FeatureType const> feature_types) const {
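The expanded MemCostBytes() is not just accounting cosmetics: later in this patch the host cache treats cumulative per-page costs as the only valid stream offsets, so the scalar fields stored alongside the buffer must be counted, or the writer's and reader's offsets would disagree. A toy illustration of the bookkeeping, with hypothetical sizes:

    #include <cstddef>
    #include <vector>

    int main() {
      // Hypothetical per-page costs as MemCostBytes() would report them:
      // compressed-buffer bytes plus the scalar fields (n_rows, is_dense, ...).
      std::vector<std::size_t> costs{4096 + 25, 4096 + 25, 2048 + 25};

      // Valid seek targets are exactly the page boundaries (prefix sums),
      // plus the grand total, which means "append at the end".
      std::vector<std::size_t> offsets{0};
      for (auto c : costs) {
        offsets.push_back(offsets.back() + c);
      }
      // offsets == {0, 4121, 8242, 10315}
      return 0;
    }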
diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh
index af11dec3f..9cc2a5130 100644
--- a/src/data/ellpack_page.cuh
+++ b/src/data/ellpack_page.cuh
@@ -66,6 +66,7 @@ struct EllpackDeviceAccessor {
       min_fvalue = cuts->min_vals_.ConstHostSpan();
     }
   }
+
   /**
    * @brief Given a row index and a feature index, returns the corresponding cut value.
    *
@@ -75,7 +76,7 @@ struct EllpackDeviceAccessor {
    * local to the current batch.
    */
   template <bool global_ridx = true>
-  [[nodiscard]] __device__ bst_bin_t GetBinIndex(size_t ridx, size_t fidx) const {
+  [[nodiscard]] __device__ bst_bin_t GetBinIndex(bst_idx_t ridx, size_t fidx) const {
     if (global_ridx) {
       ridx -= base_rowid;
     }
@@ -114,7 +115,7 @@ struct EllpackDeviceAccessor {
     return idx;
   }
 
-  [[nodiscard]] __device__ float GetFvalue(size_t ridx, size_t fidx) const {
+  [[nodiscard]] __device__ float GetFvalue(bst_idx_t ridx, size_t fidx) const {
     auto gidx = GetBinIndex(ridx, fidx);
     if (gidx == -1) {
       return std::numeric_limits<float>::quiet_NaN();
diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu
index 86d1ac6da..e17d72000 100644
--- a/src/data/ellpack_page_raw_format.cu
+++ b/src/data/ellpack_page_raw_format.cu
@@ -39,8 +39,7 @@ template <typename T>
     return false;
   }
 
-  auto ctx = Context{}.MakeCUDA(common::CurrentDevice());
-  *vec = common::MakeFixedVecWithCudaMalloc<T>(&ctx, n);
+  *vec = common::MakeFixedVecWithCudaMalloc<T>(n);
   dh::safe_cuda(cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, dh::DefaultStream()));
   return true;
 }
@@ -96,27 +95,9 @@ template <typename T>
   CHECK(this->cuts_->cut_values_.DeviceCanRead());
   impl->SetCuts(this->cuts_);
 
-  // Read vector
-  Context ctx = Context{}.MakeCUDA(common::CurrentDevice());
-  auto read_vec = [&] {
-    common::NvtxScopedRange range{common::NvtxEventAttr{"read-vec", common::NvtxRgb{127, 255, 0}}};
-    bst_idx_t n{0};
-    RET_IF_NOT(fi->Read(&n));
-    if (n == 0) {
-      return true;
-    }
-    impl->gidx_buffer = common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(&ctx, n);
-    RET_IF_NOT(fi->Read(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes()));
-    return true;
-  };
-  RET_IF_NOT(read_vec());
-
-  RET_IF_NOT(fi->Read(&impl->n_rows));
-  RET_IF_NOT(fi->Read(&impl->is_dense));
-  RET_IF_NOT(fi->Read(&impl->row_stride));
-  RET_IF_NOT(fi->Read(&impl->base_rowid));
-
+  fi->Read(page);
   dh::DefaultStream().Sync();
+
   return true;
 }
 
@@ -124,29 +105,11 @@ template <typename T>
                                                       EllpackHostCacheStream* fo) const {
   xgboost_NVTX_FN_RANGE();
 
-  bst_idx_t bytes{0};
-  auto* impl = page.Impl();
-
-  // Write vector
-  auto write_vec = [&] {
-    common::NvtxScopedRange range{common::NvtxEventAttr{"write-vec", common::NvtxRgb{127, 255, 0}}};
-    bst_idx_t n = impl->gidx_buffer.size();
-    bytes += fo->Write(n);
-
-    if (!impl->gidx_buffer.empty()) {
-      bytes += fo->Write(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes());
-    }
-  };
-
-  write_vec();
-
-  bytes += fo->Write(impl->n_rows);
-  bytes += fo->Write(impl->is_dense);
-  bytes += fo->Write(impl->row_stride);
-  bytes += fo->Write(impl->base_rowid);
-
+  fo->Write(page);
   dh::DefaultStream().Sync();
-  return bytes;
+
+  auto* impl = page.Impl();
+  return impl->MemCostBytes();
 }
 
 #undef RET_IF_NOT
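With the field-by-field helpers gone, a whole page is now the unit of I/O, and Write() reports the page's MemCostBytes() so that the byte counts pushed into the Cache line up with page boundaries. The caller-side contract, condensed from the updated HostIO test at the end of this patch:

    // Writer side: one Write() per page; the returned size feeds the cache index.
    auto writer = policy.CreateWriter({}, iter);
    auto n_bytes = format->Write(page, writer.get());  // == page.Impl()->MemCostBytes()
    cache.Push(n_bytes);

    // Reader side: seek to a recorded page boundary, then rebuild on device.
    auto reader = policy.CreateReader({}, cache.offset[i], cache.Bytes(i));
    CHECK(format->Read(&page, reader.get()));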
diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu
index 342ac8da7..4c49dbc9a 100644
--- a/src/data/ellpack_page_source.cu
+++ b/src/data/ellpack_page_source.cu
@@ -6,9 +6,11 @@
 #include <cstddef>  // for size_t
 #include <cstdint>  // for int8_t, uint64_t, uint32_t
 #include <memory>   // for shared_ptr, make_unique, make_shared
+#include <numeric>  // for accumulate
 #include <utility>  // for move
 
 #include "../common/common.h"  // for safe_cuda
+#include "../common/ref_resource_view.cuh"
 #include "../common/cuda_pinned_allocator.h"  // for pinned_allocator
 #include "../common/device_helpers.cuh"       // for CUDAStreamView, DefaultStream
 #include "../common/resource.cuh"             // for PrivateCudaMmapConstStream
@@ -17,50 +19,91 @@
 #include "ellpack_page_source.h"
 #include "proxy_dmatrix.cuh"  // for Dispatch
 #include "xgboost/base.h"     // for bst_idx_t
+#include "../common/cuda_rt_utils.h"       // for NvtxScopedRange
+#include "../common/transform_iterator.h"  // for MakeIndexTransformIter
 
 namespace xgboost::data {
-struct EllpackHostCache {
-  thrust::host_vector<std::int8_t, common::cuda_impl::pinned_allocator<std::int8_t>> cache;
+/**
+ * Cache
+ */
+EllpackHostCache::EllpackHostCache() = default;
+EllpackHostCache::~EllpackHostCache() = default;
 
-  void Resize(std::size_t n, dh::CUDAStreamView stream) {
-    stream.Sync();  // Prevent partial copy inside resize.
-    cache.resize(n);
-  }
-};
+[[nodiscard]] std::size_t EllpackHostCache::Size() const {
+  auto it = common::MakeIndexTransformIter([&](auto i) { return pages.at(i)->MemCostBytes(); });
+  return std::accumulate(it, it + pages.size(), 0l);
+}
+
+void EllpackHostCache::Push(std::unique_ptr<EllpackPageImpl> page) {
+  this->pages.emplace_back(std::move(page));
+}
+
+EllpackPageImpl const* EllpackHostCache::Get(std::int32_t k) { return this->pages.at(k).get(); }
 
+/**
+ * Cache stream.
+ */
 class EllpackHostCacheStreamImpl {
   std::shared_ptr<EllpackHostCache> cache_;
-  bst_idx_t cur_ptr_{0};
-  bst_idx_t bound_{0};
+  std::int32_t ptr_{0};
 
  public:
   explicit EllpackHostCacheStreamImpl(std::shared_ptr<EllpackHostCache> cache)
       : cache_{std::move(cache)} {}
 
-  [[nodiscard]] bst_idx_t Write(void const* ptr, bst_idx_t n_bytes) {
-    auto n = cur_ptr_ + n_bytes;
-    if (n > cache_->cache.size()) {
-      cache_->Resize(n, dh::DefaultStream());
+  auto Share() { return cache_; }
+
+  void Seek(bst_idx_t offset_bytes) {
+    std::size_t n_bytes{0};
+    std::int32_t k{-1};
+    for (std::size_t i = 0, n = cache_->pages.size(); i < n; ++i) {
+      if (n_bytes == offset_bytes) {
+        k = i;
+        break;
+      }
+      n_bytes += cache_->pages[i]->MemCostBytes();
     }
-    dh::safe_cuda(cudaMemcpyAsync(cache_->cache.data() + cur_ptr_, ptr, n_bytes, cudaMemcpyDefault,
-                                  dh::DefaultStream()));
-    cur_ptr_ = n;
-    return n_bytes;
+    if (offset_bytes == n_bytes && k == -1) {
+      k = this->cache_->pages.size();  // seek end
+    }
+    CHECK_NE(k, -1) << "Invalid offset: " << offset_bytes;
+    ptr_ = k;
   }
 
-  [[nodiscard]] bool Read(void* ptr, bst_idx_t n_bytes) {
-    CHECK_LE(cur_ptr_ + n_bytes, bound_);
-    dh::safe_cuda(cudaMemcpyAsync(ptr, cache_->cache.data() + cur_ptr_, n_bytes, cudaMemcpyDefault,
-                                  dh::DefaultStream()));
-    cur_ptr_ += n_bytes;
-    return true;
+  void Write(EllpackPage const& page) {
+    auto impl = page.Impl();
+
+    auto new_impl = std::make_unique<EllpackPageImpl>();
+    new_impl->gidx_buffer =
+        common::MakeFixedVecWithPinnedMalloc<common::CompressedByteT>(impl->gidx_buffer.size());
+    new_impl->n_rows = impl->Size();
+    new_impl->is_dense = impl->IsDense();
+    new_impl->row_stride = impl->row_stride;
+    new_impl->base_rowid = impl->base_rowid;
+
+    dh::safe_cuda(cudaMemcpyAsync(new_impl->gidx_buffer.data(), impl->gidx_buffer.data(),
+                                  impl->gidx_buffer.size_bytes(), cudaMemcpyDefault));
+
+    this->cache_->Push(std::move(new_impl));
+    ptr_ += 1;
   }
 
-  [[nodiscard]] bst_idx_t Tell() const { return cur_ptr_; }
-  void Seek(bst_idx_t offset_bytes) { cur_ptr_ = offset_bytes; }
-  void Bound(bst_idx_t offset_bytes) {
-    CHECK_LE(offset_bytes, cache_->cache.size());
-    this->bound_ = offset_bytes;
+  void Read(EllpackPage* out) const {
+    auto page = this->cache_->Get(ptr_);
+
+    auto impl = out->Impl();
+    impl->gidx_buffer =
+        common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(page->gidx_buffer.size());
+    dh::safe_cuda(cudaMemcpyAsync(impl->gidx_buffer.data(), page->gidx_buffer.data(),
+                                  page->gidx_buffer.size_bytes(), cudaMemcpyDefault));
+
+    impl->n_rows = page->Size();
+    impl->is_dense = page->IsDense();
+    impl->row_stride = page->row_stride;
+    impl->base_rowid = page->base_rowid;
   }
 };
 
@@ -73,19 +116,13 @@ EllpackHostCacheStream::EllpackHostCacheStream(std::shared_ptr<EllpackHostCache>
 
 EllpackHostCacheStream::~EllpackHostCacheStream() = default;
 
-[[nodiscard]] bst_idx_t EllpackHostCacheStream::Write(void const* ptr, bst_idx_t n_bytes) {
-  return this->p_impl_->Write(ptr, n_bytes);
-}
-
-[[nodiscard]] bool EllpackHostCacheStream::Read(void* ptr, bst_idx_t n_bytes) {
-  return this->p_impl_->Read(ptr, n_bytes);
-}
-
-[[nodiscard]] bst_idx_t EllpackHostCacheStream::Tell() const { return this->p_impl_->Tell(); }
+std::shared_ptr<EllpackHostCache> EllpackHostCacheStream::Share() { return p_impl_->Share(); }
 
 void EllpackHostCacheStream::Seek(bst_idx_t offset_bytes) { this->p_impl_->Seek(offset_bytes); }
 
-void EllpackHostCacheStream::Bound(bst_idx_t offset_bytes) { this->p_impl_->Bound(offset_bytes); }
+void EllpackHostCacheStream::Read(EllpackPage* page) const { this->p_impl_->Read(page); }
+
+void EllpackHostCacheStream::Write(EllpackPage const& page) { this->p_impl_->Write(page); }
 
 /**
  * EllpackCacheStreamPolicy
@@ -100,20 +137,18 @@ template <typename S, template <typename> typename F>
 EllpackCacheStreamPolicy<S, F>::CreateWriter(StringView, std::uint32_t iter) {
   auto fo = std::make_unique<EllpackHostCacheStream>(this->p_cache_);
   if (iter == 0) {
-    CHECK(this->p_cache_->cache.empty());
+    CHECK(this->p_cache_->Empty());
   } else {
-    fo->Seek(this->p_cache_->cache.size());
+    fo->Seek(this->p_cache_->Size());
   }
   return fo;
 }
 
 template <typename S, template <typename> typename F>
 [[nodiscard]] std::unique_ptr<typename EllpackCacheStreamPolicy<S, F>::ReaderT>
-EllpackCacheStreamPolicy<S, F>::CreateReader(StringView, bst_idx_t offset, bst_idx_t length) const {
+EllpackCacheStreamPolicy<S, F>::CreateReader(StringView, bst_idx_t offset, bst_idx_t) const {
   auto fi = std::make_unique<EllpackHostCacheStream>(this->p_cache_);
   fi->Seek(offset);
-  fi->Bound(offset + length);
-  CHECK_EQ(fi->Tell(), offset);
   return fi;
 }
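Seek() deserves a second look: it no longer moves a byte cursor but translates a byte offset into a page index, and only offsets that land exactly on a page boundary (or on the very end, for appending) are accepted. A condensed, self-contained rendering of the same walk (illustration only):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Returns the page index for a byte offset, costs.size() for "end", or -1
    // for an offset that does not fall on a page boundary.
    std::int32_t SeekPage(std::vector<std::size_t> const& costs, std::size_t offset_bytes) {
      std::size_t n_bytes{0};
      std::int32_t k{-1};
      for (std::size_t i = 0; i < costs.size(); ++i) {
        if (n_bytes == offset_bytes) {
          k = static_cast<std::int32_t>(i);
          break;
        }
        n_bytes += costs[i];
      }
      if (k == -1 && n_bytes == offset_bytes) {
        k = static_cast<std::int32_t>(costs.size());  // seek to the end (append)
      }
      return k;
    }

    // With costs {10, 20, 30}: SeekPage(..., 0) == 0, SeekPage(..., 30) == 2,
    // SeekPage(..., 60) == 3 (end), and SeekPage(..., 15) == -1 (invalid).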
diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h
index 987b120cb..61f94a262 100644
--- a/src/data/ellpack_page_source.h
+++ b/src/data/ellpack_page_source.h
@@ -22,10 +22,22 @@ namespace xgboost::data {
 // We need to decouple the storage and the view of the storage so that we can implement
-// concurrent read.
+// concurrent read. As a result, there are two classes: one for the cache storage and
+// another for the stream.
+struct EllpackHostCache {
+  std::vector<std::unique_ptr<EllpackPageImpl>> pages;
+
+  EllpackHostCache();
+  ~EllpackHostCache();
+
+  [[nodiscard]] std::size_t Size() const;
+
+  bool Empty() const { return this->Size() == 0; }
+
+  void Push(std::unique_ptr<EllpackPageImpl> page);
+  EllpackPageImpl const* Get(std::int32_t k);
+};
 
-// Dummy type to hide CUDA calls from the host compiler.
-struct EllpackHostCache;
 // Pimpl to hide CUDA calls from the host compiler.
 class EllpackHostCacheStreamImpl;
 
@@ -37,24 +49,12 @@ class EllpackHostCacheStream {
   explicit EllpackHostCacheStream(std::shared_ptr<EllpackHostCache> cache);
   ~EllpackHostCacheStream();
 
-  [[nodiscard]] bst_idx_t Write(void const* ptr, bst_idx_t n_bytes);
-  template <typename T>
-  [[nodiscard]] std::enable_if_t<std::is_trivially_copyable_v<T>, bst_idx_t> Write(T const& v) {
-    return this->Write(&v, sizeof(T));
-  }
+  std::shared_ptr<EllpackHostCache> Share();
 
-  [[nodiscard]] bool Read(void* ptr, bst_idx_t n_bytes);
-
-  template <typename T>
-  [[nodiscard]] auto Read(T* ptr) -> std::enable_if_t<std::is_trivially_copyable_v<T>, bool> {
-    return this->Read(ptr, sizeof(T));
-  }
-
-  [[nodiscard]] bst_idx_t Tell() const;
   void Seek(bst_idx_t offset_bytes);
-  // Limit the size of read. offset_bytes is the maximum offset that this stream can read
-  // to. An error is raised if the limit is exceeded.
-  void Bound(bst_idx_t offset_bytes);
+
+  void Read(EllpackPage* page) const;
+  void Write(EllpackPage const& page);
 };
 
 template <typename S>
@@ -86,6 +86,7 @@ class EllpackFormatPolicy {
     CHECK(cuts_);
     return cuts_;
   }
+
   [[nodiscard]] auto Device() const { return device_; }
 };
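From the caller's perspective the cache now behaves like a page container behind a stream facade. A sketch of the typical cycle (it mirrors CreateWriter/CreateReader in ellpack_page_source.cu; `cache`, `page`, and `offset` are placeholders):

    // Writer: iteration 0 starts from an empty cache; later iterations append.
    EllpackHostCacheStream fo{cache};
    fo.Seek(cache->Size());  // position at the end before appending
    fo.Write(page);          // deep copy into pinned host memory

    // Reader: position at a page boundary recorded by the Cache, then read.
    EllpackHostCacheStream fi{cache};
    fi.Seek(offset);
    fi.Read(&page);          // materializes the cached page on the device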
diff --git a/src/data/extmem_quantile_dmatrix.cu b/src/data/extmem_quantile_dmatrix.cu
index 2612bbb69..f7f033e95 100644
--- a/src/data/extmem_quantile_dmatrix.cu
+++ b/src/data/extmem_quantile_dmatrix.cu
@@ -4,8 +4,9 @@
 #include <memory>   // for shared_ptr
 #include <variant>  // for visit
 
-#include "batch_utils.h"     // for CheckParam, RegenGHist
-#include "ellpack_page.cuh"  // for EllpackPage
+#include "../common/cuda_rt_utils.h"  // for xgboost_NVTX_FN_RANGE
+#include "batch_utils.h"              // for CheckParam, RegenGHist
+#include "ellpack_page.cuh"           // for EllpackPage
 #include "extmem_quantile_dmatrix.h"
 #include "proxy_dmatrix.h"    // for DataIterProxy
 #include "xgboost/context.h"  // for Context
@@ -16,6 +17,8 @@ void ExtMemQuantileDMatrix::InitFromCUDA(
     Context const *ctx,
     std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> iter,
     DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr<DMatrix> ref) {
+  xgboost_NVTX_FN_RANGE();
+
   // A handle passed to external iterator.
   auto proxy = MakeProxy(proxy_handle);
   CHECK(proxy);
@@ -31,10 +34,11 @@ void ExtMemQuantileDMatrix::InitFromCUDA(
   /**
    * Generate gradient index
    */
-  auto id = MakeCache(this, ".ellpack.page", false, cache_prefix_, &cache_info_);
+  auto id = MakeCache(this, ".ellpack.page", this->on_host_, cache_prefix_, &cache_info_);
   if (on_host_ && std::get_if<EllpackHostPtr>(&ellpack_page_source_) == nullptr) {
     ellpack_page_source_.emplace<EllpackHostPtr>(nullptr);
   }
+
   std::visit(
       [&](auto &&ptr) {
         using SourceT = typename std::remove_reference_t<decltype(ptr)>::element_type;
@@ -56,6 +60,7 @@ void ExtMemQuantileDMatrix::InitFromCUDA(
   }
   CHECK_EQ(batch_cnt, ext_info.n_batches);
   CHECK_EQ(n_total_samples, ext_info.accumulated_rows);
+  this->n_batches_ = ext_info.n_batches;
 }
 
 [[nodiscard]] BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackPageImpl() {
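xgboost_NVTX_FN_RANGE() opens an NVTX range named after the enclosing function, so the initialization shows up as a labeled span in Nsight Systems. The macro itself is not reproduced in this patch, but the underlying pattern is a small RAII wrapper over the NVTX C API; a generic version (header path may vary with the CUDA version) looks like:

    #include <nvtx3/nvToolsExt.h>

    // Generic scoped NVTX range; a function-level macro can pass __func__.
    class ScopedNvtxRange {
     public:
      explicit ScopedNvtxRange(char const* name) { nvtxRangePushA(name); }
      ~ScopedNvtxRange() { nvtxRangePop(); }
    };

    #define MY_NVTX_FN_RANGE() ScopedNvtxRange nvtx_fn_range__{__func__}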
diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc
index 202ead664..eb9da871b 100644
--- a/src/data/sparse_page_dmatrix.cc
+++ b/src/data/sparse_page_dmatrix.cc
@@ -1,11 +1,17 @@
 /**
- * Copyright 2014-2023 by XGBoost Contributors
+ * Copyright 2014-2024, XGBoost Contributors
  * \file sparse_page_dmatrix.cc
  *
  * \brief The external memory version of Page Iterator.
  * \author Tianqi Chen
 */
-#include "./sparse_page_dmatrix.h"
+#include "sparse_page_dmatrix.h"
+
+#include <algorithm>  // for max
+#include <memory>     // for make_shared
+#include <string>     // for string
+#include <utility>    // for move
+#include <variant>    // for visit
 
 #include "../collective/communicator-inl.h"
 #include "batch_utils.h"  // for RegenGHist
diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh
index e82bcbf82..4be8e108f 100644
--- a/src/tree/gpu_hist/evaluate_splits.cuh
+++ b/src/tree/gpu_hist/evaluate_splits.cuh
@@ -57,7 +57,7 @@ struct CatAccessor {
 class GPUHistEvaluator {
   using CatST = common::CatBitField::value_type;  // categorical storage type
   // use pinned memory to stage the categories, used for sort based splits.
-  using Alloc = xgboost::common::cuda_impl::pinned_allocator<CatST>;
+  using Alloc = xgboost::common::cuda_impl::PinnedAllocator<CatST>;
 
 private:
  TreeEvaluator tree_evaluator_;
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 573261f9c..5d364aa82 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -735,7 +735,7 @@ class GPUHistMaker : public TreeUpdater {
   void Update(TrainParam const* param, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
               common::Span<HostDeviceVector<bst_node_t>> out_position,
               const std::vector<RegTree*>& trees) override {
-    monitor_.Start("Update");
+    monitor_.Start(__func__);
 
     CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
     auto gpair_hdv = gpair->Data();
@@ -747,7 +747,7 @@ class GPUHistMaker : public TreeUpdater {
       ++t_idx;
     }
     dh::safe_cuda(cudaGetLastError());
-    monitor_.Stop("Update");
+    monitor_.Stop(__func__);
   }
 
   void InitDataOnce(TrainParam const* param, DMatrix* dmat) {
@@ -858,7 +858,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
   void Update(TrainParam const* param, linalg::Matrix<GradientPair>* gpair, DMatrix* p_fmat,
               common::Span<HostDeviceVector<bst_node_t>> out_position,
               const std::vector<RegTree*>& trees) override {
-    monitor_.Start("Update");
+    monitor_.Start(__func__);
     this->InitDataOnce(p_fmat);
 
     // build tree
@@ -884,7 +884,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
       ++t_idx;
    }
 
-    monitor_.Stop("Update");
+    monitor_.Stop(__func__);
  }
 
   void InitDataOnce(DMatrix* p_fmat) {
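The monitor change is behavior-preserving: __func__ evaluates to the unqualified function name, so the recorded label is still "Update"; the gain is that the label can no longer drift if the function is renamed. In sketch form:

    void Update(/* ... */) {
      monitor_.Start(__func__);  // label "Update", same string as before
      // ... build the trees for this boosting round ...
      monitor_.Stop(__func__);   // stays in sync under future renames
    }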
diff --git a/tests/cpp/common/test_cuda_host_allocator.cu b/tests/cpp/common/test_cuda_host_allocator.cu
index c8e25564a..4e3224bd8 100644
--- a/tests/cpp/common/test_cuda_host_allocator.cu
+++ b/tests/cpp/common/test_cuda_host_allocator.cu
@@ -12,7 +12,7 @@
 namespace xgboost {
 TEST(CudaHostMalloc, Pinned) {
-  std::vector<float, common::cuda_impl::pinned_allocator<float>> vec;
+  std::vector<float, common::cuda_impl::PinnedAllocator<float>> vec;
   vec.resize(10);
   ASSERT_EQ(vec.size(), 10);
   Context ctx;
@@ -25,7 +25,7 @@ TEST(CudaHostMalloc, Pinned) {
 }
 
 TEST(CudaHostMalloc, Managed) {
-  std::vector<float, common::cuda_impl::managed_allocator<float>> vec;
+  std::vector<float, common::cuda_impl::ManagedAllocator<float>> vec;
   vec.resize(10);
 #if defined(__linux__)
   dh::safe_cuda(
diff --git a/tests/cpp/data/test_ellpack_page_raw_format.cu b/tests/cpp/data/test_ellpack_page_raw_format.cu
index b7bb5f902..05aec905a 100644
--- a/tests/cpp/data/test_ellpack_page_raw_format.cu
+++ b/tests/cpp/data/test_ellpack_page_raw_format.cu
@@ -77,7 +77,50 @@ TEST(EllpackPageRawFormat, DiskIOHmm) {
 }
 
 TEST(EllpackPageRawFormat, HostIO) {
-  EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy> policy;
-  TestEllpackPageRawFormat(&policy);
+  {
+    EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy> policy;
+    TestEllpackPageRawFormat(&policy);
+  }
+  {
+    auto ctx = MakeCUDACtx(0);
+    auto param = BatchParam{32, tree::TrainParam::DftSparseThreshold()};
+    EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy> policy;
+    std::unique_ptr<SparsePageFormat<EllpackPage>> format{};
+    Cache cache{false, "name", "ellpack", true};
+    for (std::size_t i = 0; i < 3; ++i) {
+      auto p_fmat = RandomDataGenerator{100, 14, 0.5}.Seed(i).GenerateDMatrix();
+      for (auto const &page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
+        if (!format) {
+          policy.SetCuts(page.Impl()->CutsShared(), ctx.Device());
+          format = policy.CreatePageFormat();
+        }
+        auto writer = policy.CreateWriter({}, i);
+        auto n_bytes = format->Write(page, writer.get());
+        ASSERT_EQ(n_bytes, page.Impl()->MemCostBytes());
+        cache.Push(n_bytes);
+      }
+    }
+    cache.Commit();
+
+    for (std::size_t i = 0; i < 3; ++i) {
+      auto reader = policy.CreateReader({}, cache.offset[i], cache.Bytes(i));
+      EllpackPage page;
+      ASSERT_TRUE(format->Read(&page, reader.get()));
+      ASSERT_EQ(page.Impl()->MemCostBytes(), cache.Bytes(i));
+      auto p_fmat = RandomDataGenerator{100, 14, 0.5}.Seed(i).GenerateDMatrix();
+      for (auto const &orig : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
+        std::vector<common::CompressedByteT> h_orig;
+        auto h_acc_orig = orig.Impl()->GetHostAccessor(&ctx, &h_orig, {});
+        std::vector<common::CompressedByteT> h_page;
+        auto h_acc = page.Impl()->GetHostAccessor(&ctx, &h_page, {});
+        ASSERT_EQ(h_orig, h_page);
+        ASSERT_EQ(h_acc_orig.NumFeatures(), h_acc.NumFeatures());
+        ASSERT_EQ(h_acc_orig.row_stride, h_acc.row_stride);
+        ASSERT_EQ(h_acc_orig.n_rows, h_acc.n_rows);
+        ASSERT_EQ(h_acc_orig.base_rowid, h_acc.base_rowid);
+        ASSERT_EQ(h_acc_orig.is_dense, h_acc.is_dense);
+      }
+    }
+  }
 }
 }  // namespace xgboost::data