diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 4b215ba58..bc38400e9 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -641,7 +641,7 @@ class DMatrix { typename XGDMatrixCallbackNext> static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr ref, DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing, - std::int32_t nthread, bst_bin_t max_bin, std::string cache); + std::int32_t nthread, bst_bin_t max_bin, std::string cache, bool on_host); virtual DMatrix *Slice(common::Span ridxs) = 0; diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 34faa4eb0..31ea232e5 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -116,6 +116,13 @@ inline int32_t CurrentDevice() { return device; } +// Helper function to get a device from a potentially CPU context. +inline auto GetDevice(xgboost::Context const *ctx) { + auto d = (ctx->IsCUDA()) ? ctx->Device() : xgboost::DeviceOrd::CUDA(dh::CurrentDevice()); + CHECK(!d.IsCPU()); + return d; +} + inline size_t TotalMemory(int device_idx) { size_t device_free = 0; size_t device_total = 0; diff --git a/src/data/data.cc b/src/data/data.cc index 8e4012dde..fcbd6cae2 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -914,9 +914,9 @@ template DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr ref, DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing, - std::int32_t nthread, bst_bin_t max_bin, std::string cache) { + std::int32_t nthread, bst_bin_t max_bin, std::string cache, bool on_host) { return new data::ExtMemQuantileDMatrix{ - iter, proxy, ref, reset, next, missing, nthread, std::move(cache), max_bin}; + iter, proxy, ref, reset, next, missing, nthread, std::move(cache), max_bin, on_host}; } template DMatrix* DMatrix::Create( DataIterHandle, DMatrixHandle, std::shared_ptr, DataIterResetCallback*, - XGDMatrixCallbackNext*, float, std::int32_t, bst_bin_t, std::string); + XGDMatrixCallbackNext*, float, std::int32_t, bst_bin_t, std::string, bool); template DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&, diff --git a/src/data/ellpack_page.cc b/src/data/ellpack_page.cc index dc6d370e5..4d918358d 100644 --- a/src/data/ellpack_page.cc +++ b/src/data/ellpack_page.cc @@ -47,6 +47,18 @@ bst_idx_t EllpackPage::Size() const { "EllpackPage is required"; return impl_->Cuts(); } + +[[nodiscard]] bst_idx_t EllpackPage::BaseRowId() const { + LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but " + "EllpackPage is required"; + return 0; +} + +[[nodiscard]] bool EllpackPage::IsDense() const { + LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but " + "EllpackPage is required"; + return false; +} } // namespace xgboost #endif // XGBOOST_USE_CUDA diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 4d240c1b7..575bcd5ce 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -39,6 +39,9 @@ void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id) return impl_->Cuts(); } +[[nodiscard]] bst_idx_t EllpackPage::BaseRowId() const { return this->Impl()->base_rowid; } +[[nodiscard]] bool EllpackPage::IsDense() const { return this->Impl()->IsDense(); } + // Bin each input data entry, store the bin indices in compressed form. __global__ void CompressBinEllpackKernel( common::CompressedBufferWriter wr, @@ -397,7 +400,7 @@ struct CopyPage { size_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bst_idx_t offset) { monitor_.Start(__func__); bst_idx_t num_elements = page->n_rows * page->row_stride; - CHECK_EQ(row_stride, page->row_stride); + CHECK_EQ(this->row_stride, page->row_stride); CHECK_EQ(NumSymbols(), page->NumSymbols()); CHECK_GE(n_rows * row_stride, offset + num_elements); if (page == this) { diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 88873d0c2..97167ad5c 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -203,6 +203,7 @@ class EllpackPageImpl { [[nodiscard]] std::shared_ptr CutsShared() const { return cuts_; } void SetCuts(std::shared_ptr cuts) { cuts_ = cuts; } + [[nodiscard]] bool IsDense() const { return is_dense; } /** @return Estimation of memory cost of this page. */ static size_t MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts&cuts) ; diff --git a/src/data/ellpack_page.h b/src/data/ellpack_page.h index 246b48296..fa312f6e7 100644 --- a/src/data/ellpack_page.h +++ b/src/data/ellpack_page.h @@ -42,6 +42,7 @@ class EllpackPage { /*! \return Number of instances in the page. */ [[nodiscard]] bst_idx_t Size() const; + [[nodiscard]] bool IsDense() const; /*! \brief Set the base row id for this page. */ void SetBaseRowId(std::size_t row_id); @@ -50,6 +51,7 @@ class EllpackPage { EllpackPageImpl* Impl() { return impl_.get(); } [[nodiscard]] common::HistogramCuts const& Cuts() const; + [[nodiscard]] bst_idx_t BaseRowId() const; private: std::unique_ptr impl_; diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index a70d9150c..7ab4819e1 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -15,7 +15,8 @@ #include "ellpack_page.cuh" // for EllpackPageImpl #include "ellpack_page.h" // for EllpackPage #include "ellpack_page_source.h" -#include "xgboost/base.h" // for bst_idx_t +#include "proxy_dmatrix.cuh" // for Dispatch +#include "xgboost/base.h" // for bst_idx_t namespace xgboost::data { struct EllpackHostCache { @@ -182,4 +183,51 @@ template void EllpackPageSourceImpl>::Fetch(); template void EllpackPageSourceImpl>::Fetch(); + +/** + * ExtEllpackPageSourceImpl + */ +template +void ExtEllpackPageSourceImpl::Fetch() { + dh::safe_cuda(cudaSetDevice(this->Device().ordinal)); + if (!this->ReadCache()) { + auto iter = this->source_->Iter(); + CHECK_EQ(this->count_, iter); + ++(*this->source_); + CHECK_GE(this->source_->Iter(), 1); + cuda_impl::Dispatch(proxy_, [this](auto const& value) { + proxy_->Info().feature_types.SetDevice(dh::GetDevice(this->ctx_)); + auto d_feature_types = proxy_->Info().feature_types.ConstDeviceSpan(); + auto n_samples = value.NumRows(); + + dh::device_vector row_counts(n_samples + 1, 0); + common::Span row_counts_span(row_counts.data().get(), row_counts.size()); + cuda_impl::Dispatch(proxy_, [=](auto const& value) { + return GetRowCounts(value, row_counts_span, dh::GetDevice(this->ctx_), this->missing_); + }); + + this->page_.reset(new EllpackPage{}); + *this->page_->Impl() = EllpackPageImpl{this->ctx_, + value, + this->missing_, + this->info_->IsDense(), + row_counts_span, + d_feature_types, + this->ext_info_.row_stride, + n_samples, + this->GetCuts()}; + this->info_->Extend(proxy_->Info(), false, true); + }); + this->page_->SetBaseRowId(this->ext_info_.base_rows.at(iter)); + this->WriteCache(); + } +} + +// Instantiation +template void +ExtEllpackPageSourceImpl>::Fetch(); +template void +ExtEllpackPageSourceImpl>::Fetch(); +template void +ExtEllpackPageSourceImpl>::Fetch(); } // namespace xgboost::data diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h index 1436f9151..987b120cb 100644 --- a/src/data/ellpack_page_source.h +++ b/src/data/ellpack_page_source.h @@ -8,6 +8,7 @@ #include // for int32_t #include // for shared_ptr #include // for move +#include // for vector #include "../common/cuda_rt_utils.h" // for SupportsPageableMem #include "../common/hist_util.h" // for HistogramCuts @@ -169,6 +170,51 @@ using EllpackPageHostSource = using EllpackPageSource = EllpackPageSourceImpl>; +template +class ExtEllpackPageSourceImpl : public ExtQantileSourceMixin { + using Super = ExtQantileSourceMixin; + + Context const* ctx_; + BatchParam p_; + DMatrixProxy* proxy_; + MetaInfo* info_; + ExternalDataInfo ext_info_; + + std::vector base_rows_; + + public: + ExtEllpackPageSourceImpl( + Context const* ctx, float missing, MetaInfo* info, ExternalDataInfo ext_info, + std::shared_ptr cache, BatchParam param, std::shared_ptr cuts, + std::shared_ptr> source, + DMatrixProxy* proxy, std::vector base_rows) + : Super{missing, + ctx->Threads(), + static_cast(info->num_col_), + ext_info.n_batches, + source, + cache}, + ctx_{ctx}, + p_{std::move(param)}, + proxy_{proxy}, + info_{info}, + ext_info_{std::move(ext_info)}, + base_rows_{std::move(base_rows)} { + this->SetCuts(std::move(cuts), ctx->Device()); + this->Fetch(); + } + + void Fetch() final; +}; + +// Cache to host +using ExtEllpackPageHostSource = + ExtEllpackPageSourceImpl>; + +// Cache to disk +using ExtEllpackPageSource = + ExtEllpackPageSourceImpl>; + #if !defined(XGBOOST_USE_CUDA) template inline void EllpackPageSourceImpl::Fetch() { @@ -177,6 +223,11 @@ inline void EllpackPageSourceImpl::Fetch() { (void)(is_dense_); common::AssertGPUSupport(); } + +template +inline void ExtEllpackPageSourceImpl::Fetch() { + common::AssertGPUSupport(); +} #endif // !defined(XGBOOST_USE_CUDA) } // namespace xgboost::data diff --git a/src/data/extmem_quantile_dmatrix.cc b/src/data/extmem_quantile_dmatrix.cc index dd929afb0..0d17fcf55 100644 --- a/src/data/extmem_quantile_dmatrix.cc +++ b/src/data/extmem_quantile_dmatrix.cc @@ -24,8 +24,8 @@ ExtMemQuantileDMatrix::ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrix DataIterResetCallback *reset, XGDMatrixCallbackNext *next, float missing, std::int32_t n_threads, std::string cache, - bst_bin_t max_bin) - : cache_prefix_{std::move(cache)} { + bst_bin_t max_bin, bool on_host) + : cache_prefix_{std::move(cache)}, on_host_{on_host} { auto iter = std::make_shared>( iter_handle, reset, next); iter->Reset(); @@ -72,13 +72,7 @@ void ExtMemQuantileDMatrix::InitFromCPU( common::HistogramCuts cuts; ExternalDataInfo ext_info; cpu_impl::GetDataShape(ctx, proxy, *iter, missing, &ext_info); - - // From here on Info() has the correct data shape - this->Info().num_row_ = ext_info.accumulated_rows; - this->Info().num_col_ = ext_info.n_features; - this->Info().num_nonzero_ = ext_info.nnz; - this->Info().SynchronizeNumberOfColumns(ctx); - ext_info.Validate(); + ext_info.SetInfo(ctx, &this->info_); /** * Generate quantiles @@ -110,7 +104,7 @@ void ExtMemQuantileDMatrix::InitFromCPU( CHECK_EQ(n_total_samples, ext_info.accumulated_rows); } -BatchSet ExtMemQuantileDMatrix::GetGradientIndexImpl() { +[[nodiscard]] BatchSet ExtMemQuantileDMatrix::GetGradientIndexImpl() { return BatchSet{BatchIterator{this->ghist_index_source_}}; } @@ -148,5 +142,13 @@ BatchSet ExtMemQuantileDMatrix::GetEllpackBatches(Context const *, this->ellpack_page_source_); return batch_set; } + +BatchSet ExtMemQuantileDMatrix::GetEllpackPageImpl() { + common::AssertGPUSupport(); + auto batch_set = + std::visit([this](auto &&ptr) { return BatchSet{BatchIterator{ptr}}; }, + this->ellpack_page_source_); + return batch_set; +} #endif } // namespace xgboost::data diff --git a/src/data/extmem_quantile_dmatrix.cu b/src/data/extmem_quantile_dmatrix.cu index da59e5c9e..2612bbb69 100644 --- a/src/data/extmem_quantile_dmatrix.cu +++ b/src/data/extmem_quantile_dmatrix.cu @@ -4,21 +4,81 @@ #include // for shared_ptr #include // for visit +#include "batch_utils.h" // for CheckParam, RegenGHist +#include "ellpack_page.cuh" // for EllpackPage #include "extmem_quantile_dmatrix.h" +#include "proxy_dmatrix.h" // for DataIterProxy +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for BatchParam namespace xgboost::data { void ExtMemQuantileDMatrix::InitFromCUDA( - Context const *, std::shared_ptr>, - DMatrixHandle, BatchParam const &, float, std::shared_ptr) { - LOG(FATAL) << "Not implemented."; + Context const *ctx, + std::shared_ptr> iter, + DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr ref) { + // A handle passed to external iterator. + auto proxy = MakeProxy(proxy_handle); + CHECK(proxy); + + /** + * Generate quantiles + */ + auto cuts = std::make_shared(); + ExternalDataInfo ext_info; + cuda_impl::MakeSketches(ctx, iter.get(), proxy, ref, p, missing, cuts, this->Info(), &ext_info); + ext_info.SetInfo(ctx, &this->info_); + + /** + * Generate gradient index + */ + auto id = MakeCache(this, ".ellpack.page", false, cache_prefix_, &cache_info_); + if (on_host_ && std::get_if(&ellpack_page_source_) == nullptr) { + ellpack_page_source_.emplace(nullptr); + } + std::visit( + [&](auto &&ptr) { + using SourceT = typename std::remove_reference_t::element_type; + ptr = std::make_shared(ctx, missing, &this->Info(), ext_info, cache_info_.at(id), + p, cuts, iter, proxy, ext_info.base_rows); + }, + ellpack_page_source_); + + /** + * Force initialize the cache and do some sanity checks along the way + */ + bst_idx_t batch_cnt = 0, k = 0; + bst_idx_t n_total_samples = 0; + for (auto const &page : this->GetEllpackPageImpl()) { + n_total_samples += page.Size(); + CHECK_EQ(page.Impl()->base_rowid, ext_info.base_rows[k]); + CHECK_EQ(page.Impl()->row_stride, ext_info.row_stride); + ++k, ++batch_cnt; + } + CHECK_EQ(batch_cnt, ext_info.n_batches); + CHECK_EQ(n_total_samples, ext_info.accumulated_rows); } -BatchSet ExtMemQuantileDMatrix::GetEllpackBatches(Context const *, - const BatchParam &) { - LOG(FATAL) << "Not implemented."; +[[nodiscard]] BatchSet ExtMemQuantileDMatrix::GetEllpackPageImpl() { auto batch_set = std::visit([this](auto &&ptr) { return BatchSet{BatchIterator{ptr}}; }, this->ellpack_page_source_); return batch_set; } + +BatchSet ExtMemQuantileDMatrix::GetEllpackBatches(Context const *, + const BatchParam ¶m) { + if (param.Initialized()) { + detail::CheckParam(this->batch_, param); + CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin(); + } + + std::visit( + [this](auto &&ptr) { + CHECK(ptr); + ptr->Reset(); + }, + this->ellpack_page_source_); + + return this->GetEllpackPageImpl(); +} } // namespace xgboost::data diff --git a/src/data/extmem_quantile_dmatrix.h b/src/data/extmem_quantile_dmatrix.h index a6e9ed0ce..d3b9f5a78 100644 --- a/src/data/extmem_quantile_dmatrix.h +++ b/src/data/extmem_quantile_dmatrix.h @@ -30,7 +30,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix { ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy, std::shared_ptr ref, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, float missing, std::int32_t n_threads, - std::string cache, bst_bin_t max_bin); + std::string cache, bst_bin_t max_bin, bool on_host); ~ExtMemQuantileDMatrix() override; [[nodiscard]] bool SingleColBlock() const override { return false; } @@ -45,9 +45,10 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix { std::shared_ptr> iter, DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr ref); - BatchSet GetGradientIndexImpl(); + [[nodiscard]] BatchSet GetGradientIndexImpl(); BatchSet GetGradientIndex(Context const *ctx, BatchParam const ¶m) override; + [[nodiscard]] BatchSet GetEllpackPageImpl(); BatchSet GetEllpackBatches(Context const *ctx, const BatchParam ¶m) override; [[nodiscard]] bool EllpackExists() const override { @@ -60,10 +61,11 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix { std::map> cache_info_; std::string cache_prefix_; + bool on_host_; BatchParam batch_; - using EllpackDiskPtr = std::shared_ptr; - using EllpackHostPtr = std::shared_ptr; + using EllpackDiskPtr = std::shared_ptr; + using EllpackHostPtr = std::shared_ptr; std::variant ellpack_page_source_; std::shared_ptr ghist_index_source_; }; diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 00a7273a0..6c1a89079 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -242,6 +242,7 @@ class GHistIndexMatrix { [[nodiscard]] bool IsDense() const { return isDense_; } void SetDense(bool is_dense) { isDense_ = is_dense; } + [[nodiscard]] bst_idx_t BaseRowId() const { return base_rowid; } /** * @brief Get the local row index. */ diff --git a/src/data/gradient_index_page_source.h b/src/data/gradient_index_page_source.h index b37dcab8b..535e86670 100644 --- a/src/data/gradient_index_page_source.h +++ b/src/data/gradient_index_page_source.h @@ -39,45 +39,6 @@ class GHistIndexFormatPolicy { void SetCuts(common::HistogramCuts cuts) { std::swap(cuts_, cuts); } }; -template > -class ExtQantileSourceMixin : public SparsePageSourceImpl { - protected: - std::shared_ptr> source_; - using Super = SparsePageSourceImpl; - - public: - ExtQantileSourceMixin(float missing, std::int32_t nthreads, bst_feature_t n_features, - bst_idx_t n_batches, std::shared_ptr cache) - : Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache} {} - // This function always operate on the source first, then the downstream. The downstream - // can assume the source to be ready. - [[nodiscard]] ExtQantileSourceMixin& operator++() final { - TryLockGuard guard{this->single_threaded_}; - // Increment self. - ++this->count_; - // Set at end. - this->at_end_ = this->count_ == this->n_batches_; - - if (this->at_end_) { - this->EndIter(); - - CHECK(this->cache_info_->written); - source_ = nullptr; // release the source - } - this->Fetch(); - - return *this; - } - - void Reset() final { - if (this->source_) { - this->source_->Reset(); - } - Super::Reset(); - } -}; - class GradientIndexPageSource : public PageSourceIncMixIn< GHistIndexMatrix, DefaultFormatStreamPolicy> { @@ -124,15 +85,14 @@ class ExtGradientIndexPageSource std::shared_ptr cache, BatchParam param, common::HistogramCuts cuts, std::shared_ptr> source, DMatrixProxy* proxy, std::vector base_rows) - : ExtQantileSourceMixin{missing, ctx->Threads(), static_cast(info->num_col_), - n_batches, cache}, + : ExtQantileSourceMixin{missing, ctx->Threads(), static_cast(info->num_col_), + n_batches, source, cache}, p_{std::move(param)}, ctx_{ctx}, proxy_{proxy}, info_{info}, feature_types_{info_->feature_types.ConstHostSpan()}, base_rows_{std::move(base_rows)} { - this->source_ = source; this->SetCuts(std::move(cuts)); this->Fetch(); } diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index 7c48edc02..29d389763 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -63,13 +63,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p, common::HistogramCuts cuts; ExternalDataInfo ext_info; cpu_impl::GetDataShape(ctx, proxy, iter, missing, &ext_info); - - // From here on Info() has the correct data shape - this->Info().num_row_ = ext_info.accumulated_rows; - this->Info().num_col_ = ext_info.n_features; - this->Info().num_nonzero_ = ext_info.nnz; - this->Info().SynchronizeNumberOfColumns(ctx); - ext_info.Validate(); + ext_info.SetInfo(ctx, &this->info_); /** * Generate quantiles diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu index 1cdb840af..0b15604e3 100644 --- a/src/data/iterative_dmatrix.cu +++ b/src/data/iterative_dmatrix.cu @@ -1,20 +1,15 @@ /** * Copyright 2020-2024, XGBoost contributors */ -#include // for max #include // for shared_ptr #include // for move -#include // for vector -#include "../collective/allreduce.h" -#include "../common/cuda_rt_utils.h" // for AllVisibleGPUs -#include "../common/hist_util.cuh" #include "batch_utils.h" // for RegenGHist, CheckParam #include "device_adapter.cuh" #include "ellpack_page.cuh" #include "iterative_dmatrix.h" #include "proxy_dmatrix.cuh" -#include "proxy_dmatrix.h" +#include "proxy_dmatrix.h" // for BatchSamples, BatchColumns #include "simple_batch_iterator.h" namespace xgboost::data { @@ -31,103 +26,32 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, dh::XGBCachingDeviceAllocator alloc; - auto num_rows = [&]() { - return cuda_impl::Dispatch(proxy, [](auto const& value) { return value.NumRows(); }); - }; - auto num_cols = [&]() { - return cuda_impl::Dispatch(proxy, [](auto const& value) { return value.NumCols(); }); - }; - - size_t row_stride = 0; - size_t nnz = 0; // Sketch for all batches. - std::vector sketch_containers; - size_t batches = 0; - size_t accumulated_rows = 0; - bst_feature_t cols = 0; - int32_t current_device; - dh::safe_cuda(cudaGetDevice(¤t_device)); + std::int32_t current_device{dh::CurrentDevice()}; auto get_ctx = [&]() { Context d_ctx = (ctx->IsCUDA()) ? *ctx : Context{}.MakeCUDA(current_device); CHECK(!d_ctx.IsCPU()); return d_ctx; }; - auto get_device = [&]() { - auto d = (ctx->IsCUDA()) ? ctx->Device() : DeviceOrd::CUDA(current_device); - CHECK(!d.IsCPU()); - return d; - }; + fmat_ctx_ = get_ctx(); /** * Generate quantiles */ auto cuts = std::make_shared(); - do { - // We use do while here as the first batch is fetched in ctor - CHECK_LT(ctx->Ordinal(), common::AllVisibleGPUs()); - dh::safe_cuda(cudaSetDevice(get_device().ordinal)); - if (cols == 0) { - cols = num_cols(); - auto rc = collective::Allreduce(ctx, linalg::MakeVec(&cols, 1), collective::Op::kMax); - SafeColl(rc); - this->info_.num_col_ = cols; - } else { - CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns."; - } - if (!ref) { - sketch_containers.emplace_back(proxy->Info().feature_types, p.max_bin, cols, num_rows(), - get_device()); - auto* p_sketch = &sketch_containers.back(); - proxy->Info().weights_.SetDevice(get_device()); - cuda_impl::Dispatch(proxy, [&](auto const& value) { - common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch); - }); - } - auto batch_rows = num_rows(); - accumulated_rows += batch_rows; - dh::device_vector row_counts(batch_rows + 1, 0); - common::Span row_counts_span(row_counts.data().get(), row_counts.size()); - row_stride = std::max(row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) { - return GetRowCounts(value, row_counts_span, get_device(), missing); - })); - nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end()); - batches++; - } while (iter.Next()); - iter.Reset(); + ExternalDataInfo ext_info; + cuda_impl::MakeSketches(ctx, &iter, proxy, ref, p, missing, cuts, this->Info(), &ext_info); + ext_info.SetInfo(ctx, &this->info_); - auto n_features = cols; - CHECK_GE(n_features, 1) << "Data must has at least 1 column."; - - dh::safe_cuda(cudaSetDevice(get_device().ordinal)); - if (!ref) { - HostDeviceVector ft; - common::SketchContainer final_sketch( - sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(), p.max_bin, cols, - accumulated_rows, get_device()); - for (auto const& sketch : sketch_containers) { - final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data()); - final_sketch.FixError(); - } - sketch_containers.clear(); - sketch_containers.shrink_to_fit(); - - final_sketch.MakeCuts(ctx, cuts.get(), this->info_.IsColumnSplit()); - } else { - GetCutsFromRef(ctx, ref, Info().num_col_, p, cuts.get()); - } - - this->info_.num_row_ = accumulated_rows; - this->info_.num_nonzero_ = nnz; - - auto init_page = [this, &cuts, row_stride, accumulated_rows, get_device]() { + auto init_page = [this, &cuts, &ext_info]() { if (!ellpack_) { // Should be put inside the while loop to protect against empty batch. In // that case device id is invalid. ellpack_.reset(new EllpackPage); - *(ellpack_->Impl()) = - EllpackPageImpl(&fmat_ctx_, cuts, this->IsDense(), row_stride, accumulated_rows); + *(ellpack_->Impl()) = EllpackPageImpl(&fmat_ctx_, cuts, this->IsDense(), ext_info.row_stride, + ext_info.accumulated_rows); } }; @@ -139,43 +63,42 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, size_t n_batches_for_verification = 0; while (iter.Next()) { init_page(); - dh::safe_cuda(cudaSetDevice(get_device().ordinal)); - auto rows = num_rows(); + dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal)); + auto rows = BatchSamples(proxy); dh::device_vector row_counts(rows + 1, 0); common::Span row_counts_span(row_counts.data().get(), row_counts.size()); cuda_impl::Dispatch(proxy, [=](auto const& value) { - return GetRowCounts(value, row_counts_span, get_device(), missing); + return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing); }); auto is_dense = this->IsDense(); - proxy->Info().feature_types.SetDevice(get_device()); + proxy->Info().feature_types.SetDevice(dh::GetDevice(ctx)); auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan(); auto new_impl = cuda_impl::Dispatch(proxy, [&](auto const& value) { return EllpackPageImpl(&fmat_ctx_, value, missing, is_dense, row_counts_span, d_feature_types, - row_stride, rows, cuts); + ext_info.row_stride, rows, cuts); }); std::size_t num_elements = ellpack_->Impl()->Copy(&fmat_ctx_, &new_impl, offset); offset += num_elements; - proxy->Info().num_row_ = num_rows(); - proxy->Info().num_col_ = cols; - if (batches != 1) { + proxy->Info().num_row_ = BatchSamples(proxy); + proxy->Info().num_col_ = ext_info.n_features; + if (ext_info.n_batches != 1) { this->info_.Extend(std::move(proxy->Info()), false, true); } n_batches_for_verification++; } - CHECK_EQ(batches, n_batches_for_verification) + CHECK_EQ(ext_info.n_batches, n_batches_for_verification) << "Different number of batches returned between 2 iterations"; - if (batches == 1) { + if (ext_info.n_batches == 1) { this->info_ = std::move(proxy->Info()); - this->info_.num_nonzero_ = nnz; + this->info_.num_nonzero_ = ext_info.nnz; CHECK_EQ(proxy->Info().labels.Size(), 0); } iter.Reset(); // Synchronise worker columns - info_.SynchronizeNumberOfColumns(ctx); } BatchSet IterativeDMatrix::GetEllpackBatches(Context const* ctx, diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index a3a78c1f2..8e62802c3 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -142,13 +142,14 @@ inline DMatrixProxy* MakeProxy(DMatrixHandle proxy) { * @brief Shape and basic information for data fetched from an external data iterator. */ struct ExternalDataInfo { - std::uint64_t n_features = 0; // The number of columns + bst_idx_t n_features = 0; // The number of columns bst_idx_t n_batches = 0; // The number of batches bst_idx_t accumulated_rows = 0; // The total number of rows bst_idx_t nnz = 0; // The number of non-missing values std::vector column_sizes; // The nnz for each column std::vector batch_nnz; // nnz for each batch std::vector base_rows{0}; // base_rowid + bst_idx_t row_stride{0}; // Used by ellpack void Validate() const { CHECK(std::none_of(this->column_sizes.cbegin(), this->column_sizes.cend(), [&](auto f) { @@ -157,6 +158,16 @@ struct ExternalDataInfo { CHECK_GE(this->n_features, 1) << "Data must has at least 1 column."; } + + void SetInfo(Context const* ctx, MetaInfo* p_info) { + // From here on Info() has the correct data shape + auto& info = *p_info; + info.num_row_ = this->accumulated_rows; + info.num_col_ = this->n_features; + info.num_nonzero_ = this->nnz; + info.SynchronizeNumberOfColumns(ctx); + this->Validate(); + } }; /** diff --git a/src/data/quantile_dmatrix.cu b/src/data/quantile_dmatrix.cu index 884cf4f71..47ccadd4e 100644 --- a/src/data/quantile_dmatrix.cu +++ b/src/data/quantile_dmatrix.cu @@ -1,10 +1,93 @@ /** - * Copyright 2024, XGBoost Contributors + * Copyright 2020-2024, XGBoost Contributors */ -#include "ellpack_page.cuh" +#include // for max +#include // for partial_sum +#include // for vector + +#include "../collective/allreduce.h" // for Allreduce +#include "../common/cuda_rt_utils.h" // for AllVisibleGPUs +#include "../common/device_vector.cuh" // for XGBCachingDeviceAllocator +#include "../common/hist_util.cuh" // for AdapterDeviceSketch +#include "../common/quantile.cuh" // for SketchContainer +#include "ellpack_page.cuh" // for EllpackPage +#include "proxy_dmatrix.cuh" // for Dispatch +#include "proxy_dmatrix.h" // for DataIterProxy +#include "quantile_dmatrix.h" // for GetCutsFromRef namespace xgboost::data { void GetCutsFromEllpack(EllpackPage const& page, common::HistogramCuts* cuts) { *cuts = page.Impl()->Cuts(); } + +namespace cuda_impl { +void MakeSketches(Context const* ctx, + DataIterProxy* iter, + DMatrixProxy* proxy, std::shared_ptr ref, BatchParam const& p, + float missing, std::shared_ptr cuts, MetaInfo const& info, + ExternalDataInfo* p_ext_info) { + dh::XGBCachingDeviceAllocator alloc; + std::vector sketch_containers; + auto& ext_info = *p_ext_info; + + do { + // We use do while here as the first batch is fetched in ctor + CHECK_LT(ctx->Ordinal(), common::AllVisibleGPUs()); + dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal)); + if (ext_info.n_features == 0) { + ext_info.n_features = data::BatchColumns(proxy); + auto rc = collective::Allreduce(ctx, linalg::MakeVec(&ext_info.n_features, 1), + collective::Op::kMax); + SafeColl(rc); + } else { + CHECK_EQ(ext_info.n_features, ::xgboost::data::BatchColumns(proxy)) + << "Inconsistent number of columns."; + } + if (!ref) { + sketch_containers.emplace_back(proxy->Info().feature_types, p.max_bin, ext_info.n_features, + data::BatchSamples(proxy), dh::GetDevice(ctx)); + auto* p_sketch = &sketch_containers.back(); + proxy->Info().weights_.SetDevice(dh::GetDevice(ctx)); + cuda_impl::Dispatch(proxy, [&](auto const& value) { + common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch); + }); + } + auto batch_rows = data::BatchSamples(proxy); + ext_info.accumulated_rows += batch_rows; + dh::device_vector row_counts(batch_rows + 1, 0); + common::Span row_counts_span(row_counts.data().get(), row_counts.size()); + ext_info.row_stride = + std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) { + return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing); + })); + ext_info.nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end()); + ext_info.n_batches++; + ext_info.base_rows.push_back(batch_rows); + } while (iter->Next()); + iter->Reset(); + + CHECK_GE(ext_info.n_features, 1) << "Data must has at least 1 column."; + std::partial_sum(ext_info.base_rows.cbegin(), ext_info.base_rows.cend(), + ext_info.base_rows.begin()); + + // Get reference + dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal)); + if (!ref) { + HostDeviceVector ft; + common::SketchContainer final_sketch( + sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(), p.max_bin, + ext_info.n_features, ext_info.accumulated_rows, dh::GetDevice(ctx)); + for (auto const& sketch : sketch_containers) { + final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data()); + final_sketch.FixError(); + } + sketch_containers.clear(); + sketch_containers.shrink_to_fit(); + + final_sketch.MakeCuts(ctx, cuts.get(), info.IsColumnSplit()); + } else { + GetCutsFromRef(ctx, ref, ext_info.n_features, p, cuts.get()); + } +} +} // namespace cuda_impl } // namespace xgboost::data diff --git a/src/data/quantile_dmatrix.h b/src/data/quantile_dmatrix.h index b5b7b2e2b..62d7e0e75 100644 --- a/src/data/quantile_dmatrix.h +++ b/src/data/quantile_dmatrix.h @@ -104,4 +104,12 @@ void MakeSketches(Context const *ctx, common::HistogramCuts *cuts, BatchParam const &p, MetaInfo const &info, ExternalDataInfo const &ext_info, std::vector *p_h_ft); } // namespace cpu_impl + +namespace cuda_impl { +void MakeSketches(Context const *ctx, + DataIterProxy *iter, + DMatrixProxy *proxy, std::shared_ptr ref, BatchParam const &p, + float missing, std::shared_ptr cuts, MetaInfo const &info, + ExternalDataInfo *p_ext_info); +} // namespace cuda_impl } // namespace xgboost::data diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index 579b62c3b..202ead664 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -38,30 +38,23 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p auto iter = DataIterProxy{ iter_, reset_, next_}; - std::uint32_t n_batches = 0; - bst_feature_t n_features = 0; - bst_idx_t n_samples = 0; - bst_idx_t nnz = 0; + ExternalDataInfo ext_info; // The proxy is iterated together with the sparse page source so we can obtain all // information in 1 pass. for (auto const &page : this->GetRowBatchesImpl(&ctx)) { this->info_.Extend(std::move(proxy->Info()), false, false); - n_features = std::max(n_features, BatchColumns(proxy)); - n_samples += BatchSamples(proxy); - nnz += page.data.Size(); - n_batches++; + ext_info.n_features = + std::max(static_cast(ext_info.n_features), BatchColumns(proxy)); + ext_info.accumulated_rows += BatchSamples(proxy); + ext_info.nnz += page.data.Size(); + ext_info.n_batches++; } iter.Reset(); - this->n_batches_ = n_batches; - this->info_.num_row_ = n_samples; - this->info_.num_col_ = n_features; - this->info_.num_nonzero_ = nnz; - - info_.SynchronizeNumberOfColumns(&ctx); - CHECK_NE(info_.num_col_, 0); + this->n_batches_ = ext_info.n_batches; + ext_info.SetInfo(&ctx, &this->info_); fmat_ctx_ = ctx; } diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index e33fe8543..e750f00fc 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -585,5 +585,50 @@ class SortedCSCPageSource : public PageSourceIncMixIn { this->Fetch(); } }; + +/** + * @brief operator++ implementation for QDM. + */ +template > +class ExtQantileSourceMixin : public SparsePageSourceImpl { + protected: + std::shared_ptr> source_; + using Super = SparsePageSourceImpl; + + public: + ExtQantileSourceMixin( + float missing, std::int32_t nthreads, bst_feature_t n_features, bst_idx_t n_batches, + std::shared_ptr> source, + std::shared_ptr cache) + : Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, + source_{std::move(source)} {} + // This function always operate on the source first, then the downstream. The downstream + // can assume the source to be ready. + [[nodiscard]] ExtQantileSourceMixin& operator++() final { + TryLockGuard guard{this->single_threaded_}; + // Increment self. + ++this->count_; + // Set at end. + this->at_end_ = this->count_ == this->n_batches_; + + if (this->at_end_) { + this->EndIter(); + + CHECK(this->cache_info_->written); + source_ = nullptr; // release the source + } + this->Fetch(); + + return *this; + } + + void Reset() final { + if (this->source_) { + this->source_->Reset(); + } + Super::Reset(); + } +}; } // namespace xgboost::data #endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ diff --git a/tests/cpp/data/test_extmem_quantile_dmatrix.cc b/tests/cpp/data/test_extmem_quantile_dmatrix.cc index e9d4b214b..623637ea4 100644 --- a/tests/cpp/data/test_extmem_quantile_dmatrix.cc +++ b/tests/cpp/data/test_extmem_quantile_dmatrix.cc @@ -1,6 +1,8 @@ /** * Copyright 2024, XGBoost Contributors */ +#include "test_extmem_quantile_dmatrix.h" // for TestExtMemQdmBasic + #include #include // for BatchParam @@ -9,76 +11,30 @@ #include "../../../src/common/column_matrix.h" // for ColumnMatrix #include "../../../src/data/gradient_index.h" // for GHistIndexMatrix #include "../../../src/tree/param.h" // for TrainParam -#include "../helpers.h" // for RandomDataGenerator namespace xgboost::data { namespace { class ExtMemQuantileDMatrixCpu : public ::testing::TestWithParam { public: void Run(float sparsity) { - bst_idx_t n_samples = 256, n_features = 16, n_batches = 4; - bst_bin_t max_bin = 64; - bst_target_t n_targets = 3; - auto p_fmat = RandomDataGenerator{n_samples, n_features, sparsity} - .Bins(max_bin) - .Batches(n_batches) - .Targets(n_targets) - .GenerateExtMemQuantileDMatrix("temp", true); - ASSERT_FALSE(p_fmat->SingleColBlock()); - - BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()}; - Context ctx; - - // Loop over the batches and count the number of pages - bst_idx_t batch_cnt = 0; - bst_idx_t base_cnt = 0; - bst_idx_t row_cnt = 0; - for (auto const& page : p_fmat->GetBatches(&ctx, p)) { - ASSERT_EQ(page.base_rowid, base_cnt); - ++batch_cnt; - base_cnt += n_samples / n_batches; - row_cnt += page.Size(); - ASSERT_EQ((sparsity == 0.0f), page.IsDense()); - } - ASSERT_EQ(n_batches, batch_cnt); - ASSERT_EQ(p_fmat->Info().num_row_, n_samples); - EXPECT_EQ(p_fmat->Info().num_row_, row_cnt); - ASSERT_EQ(p_fmat->Info().num_col_, n_features); - if (sparsity == 0.0f) { - ASSERT_EQ(p_fmat->Info().num_nonzero_, n_samples * n_features); - } else { - ASSERT_LT(p_fmat->Info().num_nonzero_, n_samples * n_features); - ASSERT_GT(p_fmat->Info().num_nonzero_, 0); - } - ASSERT_EQ(p_fmat->Info().labels.Shape(0), n_samples); - ASSERT_EQ(p_fmat->Info().labels.Shape(1), n_targets); - - // Compare against the sparse page DMatrix - auto p_sparse = RandomDataGenerator{n_samples, n_features, sparsity} - .Bins(max_bin) - .Batches(n_batches) - .Targets(n_targets) - .GenerateSparsePageDMatrix("temp", true); - auto it = p_fmat->GetBatches(&ctx, p).begin(); - for (auto const& page : p_sparse->GetBatches(&ctx, p)) { - auto orig = it.Page(); + auto equal = [](Context const*, GHistIndexMatrix const& orig, GHistIndexMatrix const& sparse) { // Check the CSR matrix - auto orig_cuts = it.Page()->Cuts(); - auto sparse_cuts = page.Cuts(); + auto orig_cuts = orig.Cuts(); + auto sparse_cuts = sparse.Cuts(); ASSERT_EQ(orig_cuts.Values(), sparse_cuts.Values()); ASSERT_EQ(orig_cuts.MinValues(), sparse_cuts.MinValues()); ASSERT_EQ(orig_cuts.Ptrs(), sparse_cuts.Ptrs()); - auto orig_ptr = orig->data.data(); - auto sparse_ptr = page.data.data(); - ASSERT_EQ(orig->data.size(), page.data.size()); + auto orig_ptr = orig.data.data(); + auto sparse_ptr = sparse.data.data(); + ASSERT_EQ(orig.data.size(), sparse.data.size()); - auto equal = std::equal(orig_ptr, orig_ptr + orig->data.size(), sparse_ptr); + auto equal = std::equal(orig_ptr, orig_ptr + orig.data.size(), sparse_ptr); ASSERT_TRUE(equal); // Check the column matrix - common::ColumnMatrix const& orig_columns = orig->Transpose(); - common::ColumnMatrix const& sparse_columns = page.Transpose(); + common::ColumnMatrix const& orig_columns = orig.Transpose(); + common::ColumnMatrix const& sparse_columns = sparse.Transpose(); std::string str_orig, str_sparse; common::AlignedMemWriteStream fo_orig{&str_orig}, fo_sparse{&str_sparse}; @@ -86,18 +42,10 @@ class ExtMemQuantileDMatrixCpu : public ::testing::TestWithParam { auto n_bytes_sparse = sparse_columns.Write(&fo_sparse); ASSERT_EQ(n_bytes_orig, n_bytes_sparse); ASSERT_EQ(str_orig, str_sparse); + }; - ++it; - } - - // Check meta info - auto h_y_sparse = p_sparse->Info().labels.HostView(); - auto h_y = p_fmat->Info().labels.HostView(); - for (std::size_t i = 0, m = h_y_sparse.Shape(0); i < m; ++i) { - for (std::size_t j = 0, n = h_y_sparse.Shape(1); j < n; ++j) { - ASSERT_EQ(h_y(i, j), h_y_sparse(i, j)); - } - } + Context ctx; + TestExtMemQdmBasic(&ctx, false, sparsity, equal); } }; } // anonymous namespace diff --git a/tests/cpp/data/test_extmem_quantile_dmatrix.cu b/tests/cpp/data/test_extmem_quantile_dmatrix.cu new file mode 100644 index 000000000..3b65dffa1 --- /dev/null +++ b/tests/cpp/data/test_extmem_quantile_dmatrix.cu @@ -0,0 +1,45 @@ +/** + * Copyright 2024, XGBoost Contributors + */ +#include +#include // for BatchParam + +#include // for vector + +#include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl +#include "../helpers.h" // for RandomDataGenerator +#include "test_extmem_quantile_dmatrix.h" // for TestExtMemQdmBasic + +namespace xgboost::data { +class ExtMemQuantileDMatrixGpu : public ::testing::TestWithParam { + public: + void Run(float sparsity) { + auto equal = [](Context const* ctx, EllpackPage const& orig, EllpackPage const& sparse) { + auto const& orig_cuts = orig.Cuts(); + auto const& sparse_cuts = sparse.Cuts(); + ASSERT_EQ(orig_cuts.Values(), sparse_cuts.Values()); + ASSERT_EQ(orig_cuts.MinValues(), sparse_cuts.MinValues()); + ASSERT_EQ(orig_cuts.Ptrs(), sparse_cuts.Ptrs()); + + std::vector h_orig, h_sparse; + auto orig_acc = orig.Impl()->GetHostAccessor(ctx, &h_orig, {}); + auto sparse_acc = sparse.Impl()->GetHostAccessor(ctx, &h_sparse, {}); + ASSERT_EQ(h_orig.size(), h_sparse.size()); + + auto equal = std::equal(h_orig.cbegin(), h_orig.cend(), h_sparse.cbegin()); + ASSERT_TRUE(equal); + }; + + auto ctx = MakeCUDACtx(0); + TestExtMemQdmBasic(&ctx, true, sparsity, equal); + TestExtMemQdmBasic(&ctx, false, sparsity, equal); + } +}; + +TEST_P(ExtMemQuantileDMatrixGpu, Basic) { this->Run(this->GetParam()); } + +INSTANTIATE_TEST_SUITE_P(ExtMemQuantileDMatrix, ExtMemQuantileDMatrixGpu, ::testing::ValuesIn([] { + std::vector sparsities{0.0f, 0.2f, 0.4f, 0.8f}; + return sparsities; + }())); +} // namespace xgboost::data diff --git a/tests/cpp/data/test_extmem_quantile_dmatrix.h b/tests/cpp/data/test_extmem_quantile_dmatrix.h new file mode 100644 index 000000000..25f2e0654 --- /dev/null +++ b/tests/cpp/data/test_extmem_quantile_dmatrix.h @@ -0,0 +1,73 @@ +/** + * Copyright 2024, XGBoost Contributors + */ +#include +#include + +#include "../../../src/tree/param.h" // for TrainParam +#include "../helpers.h" // for RandomDataGenerator + +namespace xgboost::data { +template +void TestExtMemQdmBasic(Context const* ctx, bool on_host, float sparsity, Equal&& check_equal) { + bst_idx_t n_samples = 256, n_features = 16, n_batches = 4; + bst_bin_t max_bin = 64; + bst_target_t n_targets = 3; + BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()}; + + auto p_fmat = RandomDataGenerator{n_samples, n_features, sparsity} + .Bins(max_bin) + .Batches(n_batches) + .Targets(n_targets) + .Device(ctx->Device()) + .OnHost(on_host) + .GenerateExtMemQuantileDMatrix("temp", true); + ASSERT_FALSE(p_fmat->SingleColBlock()); + + // Loop over the batches and count the number of pages + bst_idx_t batch_cnt = 0, base_cnt = 0, row_cnt = 0; + for (auto const& page : p_fmat->GetBatches(ctx, p)) { + ASSERT_EQ(page.BaseRowId(), base_cnt); + ++batch_cnt; + base_cnt += n_samples / n_batches; + row_cnt += page.Size(); + ASSERT_EQ((sparsity == 0.0f), page.IsDense()); + } + ASSERT_EQ(n_batches, batch_cnt); + ASSERT_EQ(p_fmat->Info().num_row_, n_samples); + EXPECT_EQ(p_fmat->Info().num_row_, row_cnt); + ASSERT_EQ(p_fmat->Info().num_col_, n_features); + if (sparsity == 0.0f) { + ASSERT_EQ(p_fmat->Info().num_nonzero_, n_samples * n_features); + } else { + ASSERT_LT(p_fmat->Info().num_nonzero_, n_samples * n_features); + ASSERT_GT(p_fmat->Info().num_nonzero_, 0); + } + ASSERT_EQ(p_fmat->Info().labels.Shape(0), n_samples); + ASSERT_EQ(p_fmat->Info().labels.Shape(1), n_targets); + + // Compare against the sparse page DMatrix + auto p_sparse = RandomDataGenerator{n_samples, n_features, sparsity} + .Bins(max_bin) + .Batches(n_batches) + .Targets(n_targets) + .Device(ctx->Device()) + .OnHost(on_host) + .GenerateSparsePageDMatrix("temp", true); + auto it = p_fmat->GetBatches(ctx, p).begin(); + for (auto const& page : p_sparse->GetBatches(ctx, p)) { + auto orig = it.Page(); + check_equal(ctx, *orig, page); + ++it; + } + + // Check meta info + auto h_y_sparse = p_sparse->Info().labels.HostView(); + auto h_y = p_fmat->Info().labels.HostView(); + for (std::size_t i = 0, m = h_y_sparse.Shape(0); i < m; ++i) { + for (std::size_t j = 0, n = h_y_sparse.Shape(1); j < n; ++j) { + ASSERT_EQ(h_y(i, j), h_y_sparse(i, j)); + } + } +} +} // namespace xgboost::data diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 036f761ff..05f843164 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -483,12 +483,15 @@ void RandomDataGenerator::GenerateCSR( } CHECK(iter); - std::shared_ptr p_fmat{ - DMatrix::Create(static_cast(iter.get()), iter->Proxy(), nullptr, Reset, Next, - std::numeric_limits::quiet_NaN(), 0, this->bins_, prefix)}; + std::shared_ptr p_fmat{DMatrix::Create( + static_cast(iter.get()), iter->Proxy(), nullptr, Reset, Next, + std::numeric_limits::quiet_NaN(), 0, this->bins_, prefix, this->on_host_)}; - auto page_path = data::MakeId(prefix, p_fmat.get()) + ".gradient_index.page"; - EXPECT_TRUE(FileExists(page_path)) << page_path; + auto page_path = data::MakeId(prefix, p_fmat.get()); + page_path += device_.IsCPU() ? ".gradient_index.page" : ".ellpack.page"; + if (!this->on_host_) { + EXPECT_TRUE(FileExists(page_path)) << page_path; + } if (with_label) { RandomDataGenerator{static_cast(p_fmat->Info().num_row_), this->n_targets_, 0.0f}