From 9cb4c938da8a1e1bcb3794344838687302c21a4e Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 3 Jul 2024 03:48:18 +0800 Subject: [PATCH] [EM] Move prefetch in reset into the end of the iteration. (#10529) --- src/data/gradient_index_page_source.cc | 3 + src/data/sparse_page_source.h | 78 ++++++++++++++-------- tests/cpp/data/test_sparse_page_dmatrix.cc | 29 ++++++-- tests/cpp/data/test_sparse_page_dmatrix.cu | 71 +++++++++++++++----- 4 files changed, 133 insertions(+), 48 deletions(-) diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc index f1ceb282a..0fee1c9fb 100644 --- a/src/data/gradient_index_page_source.cc +++ b/src/data/gradient_index_page_source.cc @@ -9,6 +9,9 @@ void GradientIndexPageSource::Fetch() { if (count_ != 0 && !sync_) { // source is initialized to be the 0th page during construction, so when count_ is 0 // there's no need to increment the source. + // + // The mixin doesn't sync the source if `sync_` is false, we need to sync it + // ourselves. ++(*source_); } // This is not read from cache so we still need it to be synced with sparse page source. diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 89aa86ace..18a149059 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -42,7 +42,7 @@ struct Cache { std::string name; std::string format; // offset into binary cache file. - std::vector offset; + std::vector offset; Cache(bool w, std::string n, std::string fmt, bool on_host) : written{w}, on_host{on_host}, name{std::move(n)}, format{std::move(fmt)} { @@ -61,7 +61,7 @@ struct Cache { /** * @brief Record a page with size of n_bytes. */ - void Push(std::size_t n_bytes) { offset.push_back(n_bytes); } + void Push(bst_idx_t n_bytes) { offset.push_back(n_bytes); } /** * @brief Returns the view start and length for the i^th page. */ @@ -73,7 +73,7 @@ struct Cache { /** * @brief Get the number of bytes for the i^th page. 
*/ - [[nodiscard]] std::uint64_t Bytes(std::size_t i) const { return offset.at(i + 1) - offset[i]; } + [[nodiscard]] bst_idx_t Bytes(std::size_t i) const { return offset.at(i + 1) - offset[i]; } /** * @brief Call this once the write for the cache is complete. */ @@ -218,7 +218,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol common::Monitor monitor_; [[nodiscard]] bool ReadCache() { - CHECK(!at_end_); if (!cache_info_->written) { return false; } @@ -259,11 +258,13 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol return page; }); } + CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }), n_prefetch_batches) << "Sparse DMatrix assumes forward iteration."; monitor_.Start("Wait"); + CHECK((*ring_)[count_].valid()); page_ = (*ring_)[count_].get(); CHECK(!(*ring_)[count_].valid()); monitor_.Stop("Wait"); @@ -331,12 +332,28 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol return at_end_; } + // Call this at the last iteration. + void EndIter() { + CHECK_EQ(this->cache_info_->offset.size(), this->n_batches_ + 1); + this->cache_info_->Commit(); + if (this->n_batches_ != 0) { + CHECK_EQ(this->count_, this->n_batches_); + } + CHECK_GE(this->count_, 1); + this->count_ = 0; + } + virtual void Reset() { TryLockGuard guard{single_threaded_}; - at_end_ = false; - count_ = 0; - // Pre-fetch for the next round of iterations. - this->Fetch(); + + this->at_end_ = false; + auto cnt = this->count_; + this->count_ = 0; + if (cnt != 0) { + // The last iteration did not get to the end, clear the ring to start from 0. 
+ this->ring_ = std::make_unique(); + this->Fetch(); + } } }; @@ -404,16 +421,11 @@ class SparsePageSource : public SparsePageSourceImpl { CHECK_LE(count_, n_batches_); if (at_end_) { - CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1); - cache_info_->Commit(); - if (n_batches_ != 0) { - CHECK_EQ(count_, n_batches_); - } - CHECK_GE(count_, 1); - proxy_ = nullptr; - } else { - this->Fetch(); + this->EndIter(); + this->proxy_ = nullptr; } + + this->Fetch(); return *this; } @@ -446,36 +458,46 @@ class PageSourceIncMixIn : public SparsePageSourceImpl { PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features, bst_idx_t n_batches, std::shared_ptr cache, bool sync) : Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {} - + // This function always operates on the source first, then the downstream. The downstream + // can assume the source to be ready. [[nodiscard]] PageSourceIncMixIn& operator++() final { TryLockGuard guard{this->single_threaded_}; + // Increment the source. if (sync_) { ++(*source_); } - + // Increment self. ++this->count_; + // Set at end. this->at_end_ = this->count_ == this->n_batches_; if (this->at_end_) { - this->cache_info_->Commit(); - if (this->n_batches_ != 0) { - CHECK_EQ(this->count_, this->n_batches_); + // If this is the first round of iterations, we have just built the binary cache + // from source. For a non-sync page type, the source hasn't been updated to the end + // iteration yet due to skipped increment. We increment the source here and it will + // call the `EndIter` method itself. + bool src_need_inc = !sync_ && this->source_->Iter() != 0; + if (src_need_inc) { + CHECK_EQ(this->source_->Iter(), this->count_ - 1); + ++(*source_); + } + this->EndIter(); + + if (src_need_inc) { + CHECK(this->cache_info_->written); } - CHECK_GE(this->count_, 1); - } else { - this->Fetch(); } + this->Fetch(); if (sync_) { + // Sanity check. 
CHECK_EQ(source_->Iter(), this->count_); } return *this; } void Reset() final { - if (sync_) { - this->source_->Reset(); - } + this->source_->Reset(); Super::Reset(); } }; diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index 33308be19..3aeb42abc 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -118,7 +118,8 @@ TEST(SparsePageDMatrix, RetainSparsePage) { // Test GHistIndexMatrix can avoid loading sparse page after the initialization. TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) { dmlc::TemporaryDirectory tmpdir; - auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(6).GenerateSparsePageDMatrix( + std::size_t n_batches = 6; + auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix( tmpdir.path + "/", true); Context ctx; bst_bin_t n_bins{256}; @@ -171,12 +172,30 @@ TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) { // Restore the batch parameter by passing it in again through check_ghist check_ghist(); } + // half the pages - auto it = Xy->GetBatches(&ctx).begin(); - for (std::int32_t i = 0; i < 3; ++i) { - ++it; + { + auto it = Xy->GetBatches(&ctx).begin(); + for (std::size_t i = 0; i < n_batches / 2; ++i) { + ++it; + } + check_ghist(); + } + { + auto it = Xy->GetBatches(&ctx, batch_param).begin(); + for (std::size_t i = 0; i < n_batches / 2; ++i) { + ++it; + } + check_ghist(); + } + { + BatchParam regen{n_bins, common::Span{hess.data(), hess.size()}, true}; + auto it = Xy->GetBatches(&ctx, regen).begin(); + for (std::size_t i = 0; i < n_batches / 2; ++i) { + ++it; + } + check_ghist(); } - check_ghist(); } TEST(SparsePageDMatrix, MetaInfo) { diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu index 7200b96a9..327f2ba63 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cu +++ b/tests/cpp/data/test_sparse_page_dmatrix.cu @@ -41,31 +41,77 @@ TEST(SparsePageDMatrix, 
EllpackPage) { TEST(SparsePageDMatrix, EllpackSkipSparsePage) { // Test Ellpack can avoid loading sparse page after the initialization. dmlc::TemporaryDirectory tmpdir; - auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(6).GenerateSparsePageDMatrix( + std::size_t n_batches = 6; + auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix( tmpdir.path + "/", true); auto ctx = MakeCUDACtx(0); + auto cpu = ctx.MakeCPU(); bst_bin_t n_bins{256}; double sparse_thresh{0.8}; BatchParam batch_param{n_bins, sparse_thresh}; - std::int32_t k = 0; - for (auto const& page : Xy->GetBatches(&ctx, batch_param)) { - auto impl = page.Impl(); - ASSERT_EQ(page.Size(), 30); - ASSERT_EQ(k, impl->base_rowid); - k += page.Size(); - } + auto check_ellpack = [&]() { + std::int32_t k = 0; + for (auto const& page : Xy->GetBatches(&ctx, batch_param)) { + auto impl = page.Impl(); + ASSERT_EQ(page.Size(), 30); + ASSERT_EQ(k, impl->base_rowid); + k += page.Size(); + } + }; auto casted = std::dynamic_pointer_cast(Xy); CHECK(casted); + check_ellpack(); + // Make sure the number of fetches doesn't change (no new fetch) auto n_fetches = casted->SparsePageFetchCount(); - for (std::int32_t i = 0; i < 3; ++i) { + for (std::size_t i = 0; i < 3; ++i) { for ([[maybe_unused]] auto const& page : Xy->GetBatches(&ctx, batch_param)) { } auto casted = std::dynamic_pointer_cast(Xy); ASSERT_EQ(casted->SparsePageFetchCount(), n_fetches); } + check_ellpack(); + + dh::device_vector hess(Xy->Info().num_row_, 1.0f); + for (std::size_t i = 0; i < 4; ++i) { + for ([[maybe_unused]] auto const& page : Xy->GetBatches(&ctx)) { + } + for ([[maybe_unused]] auto const& page : Xy->GetBatches(&cpu)) { + } + for ([[maybe_unused]] auto const& page : Xy->GetBatches(&ctx, batch_param)) { + } + // Approx tree method pages + { + BatchParam regen{n_bins, dh::ToSpan(hess), false}; + for ([[maybe_unused]] auto const& page : Xy->GetBatches(&ctx, regen)) { + } + } + { + BatchParam regen{n_bins, dh::ToSpan(hess), 
true}; + for ([[maybe_unused]] auto const& page : Xy->GetBatches(&ctx, regen)) { + } + } + + check_ellpack(); + } + + // half the pages + { + auto it = Xy->GetBatches(&ctx).begin(); + for (std::size_t i = 0; i < n_batches / 2; ++i) { + ++it; + } + check_ellpack(); + } + { + auto it = Xy->GetBatches(&ctx, batch_param).begin(); + for (std::size_t i = 0; i < n_batches / 2; ++i) { + ++it; + } + check_ellpack(); + } } TEST(SparsePageDMatrix, MultipleEllpackPages) { @@ -115,12 +161,7 @@ TEST(SparsePageDMatrix, RetainEllpackPage) { for (size_t i = 0; i < iterators.size(); ++i) { ASSERT_EQ((*iterators[i]).Impl()->gidx_buffer.HostVector(), gidx_buffers.at(i).HostVector()); - if (i != iterators.size() - 1) { - ASSERT_EQ(iterators[i].use_count(), 1); - } else { - // The last batch is still being held by sparse page DMatrix. - ASSERT_EQ(iterators[i].use_count(), 2); - } + ASSERT_EQ(iterators[i].use_count(), 1); } // make sure it's const and the caller can not modify the content of page.