[EM] Move prefetch in reset into the end of the iteration. (#10529)

This commit is contained in:
Jiaming Yuan
2024-07-03 03:48:18 +08:00
committed by GitHub
parent e537b0969f
commit 9cb4c938da
4 changed files with 133 additions and 48 deletions

View File

@@ -9,6 +9,9 @@ void GradientIndexPageSource::Fetch() {
if (count_ != 0 && !sync_) {
// source is initialized to be the 0th page during construction, so when count_ is 0
// there's no need to increment the source.
//
// The mixin doesn't sync the source if `sync_` is false, so we need to sync it
// ourselves.
++(*source_);
}
// This is not read from cache so we still need it to be synced with sparse page source.

View File

@@ -42,7 +42,7 @@ struct Cache {
std::string name;
std::string format;
// offset into binary cache file.
std::vector<std::uint64_t> offset;
std::vector<bst_idx_t> offset;
Cache(bool w, std::string n, std::string fmt, bool on_host)
: written{w}, on_host{on_host}, name{std::move(n)}, format{std::move(fmt)} {
@@ -61,7 +61,7 @@ struct Cache {
/**
* @brief Record a page with size of n_bytes.
*/
void Push(std::size_t n_bytes) { offset.push_back(n_bytes); }
void Push(bst_idx_t n_bytes) { offset.push_back(n_bytes); }
/**
* @brief Returns the view start and length for the i^th page.
*/
@@ -73,7 +73,7 @@ struct Cache {
/**
* @brief Get the number of bytes for the i^th page.
*/
[[nodiscard]] std::uint64_t Bytes(std::size_t i) const { return offset.at(i + 1) - offset[i]; }
[[nodiscard]] bst_idx_t Bytes(std::size_t i) const { return offset.at(i + 1) - offset[i]; }
/**
* @brief Call this once the write for the cache is complete.
*/
@@ -218,7 +218,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
common::Monitor monitor_;
[[nodiscard]] bool ReadCache() {
CHECK(!at_end_);
if (!cache_info_->written) {
return false;
}
@@ -259,11 +258,13 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
return page;
});
}
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
n_prefetch_batches)
<< "Sparse DMatrix assumes forward iteration.";
monitor_.Start("Wait");
CHECK((*ring_)[count_].valid());
page_ = (*ring_)[count_].get();
CHECK(!(*ring_)[count_].valid());
monitor_.Stop("Wait");
@@ -331,12 +332,28 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
return at_end_;
}
// Call this at the last iteration.
//
// Finalizes a full pass over the data: commits the cache (making it readable by
// later passes, which check `cache_info_->written`) and rewinds the batch
// counter so the next pass starts from batch 0.
void EndIter() {
// The cache stores one starting offset per batch plus a trailing end offset,
// hence `n_batches_ + 1` entries after a complete pass.
CHECK_EQ(this->cache_info_->offset.size(), this->n_batches_ + 1);
// Mark the cache write as complete.
this->cache_info_->Commit();
if (this->n_batches_ != 0) {
// A finished pass must have incremented through every batch.
CHECK_EQ(this->count_, this->n_batches_);
}
// Even an empty-batches configuration performs at least one increment.
CHECK_GE(this->count_, 1);
// Rewind for the next round of iterations.
this->count_ = 0;
}
virtual void Reset() {
TryLockGuard guard{single_threaded_};
at_end_ = false;
count_ = 0;
// Pre-fetch for the next round of iterations.
this->Fetch();
this->at_end_ = false;
auto cnt = this->count_;
this->count_ = 0;
if (cnt != 0) {
// The last iteration did not reach the end; clear the ring to start from 0.
this->ring_ = std::make_unique<Ring>();
this->Fetch();
}
}
};
@@ -404,16 +421,11 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
CHECK_LE(count_, n_batches_);
if (at_end_) {
CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1);
cache_info_->Commit();
if (n_batches_ != 0) {
CHECK_EQ(count_, n_batches_);
}
CHECK_GE(count_, 1);
proxy_ = nullptr;
} else {
this->Fetch();
this->EndIter();
this->proxy_ = nullptr;
}
this->Fetch();
return *this;
}
@@ -446,36 +458,46 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features,
bst_idx_t n_batches, std::shared_ptr<Cache> cache, bool sync)
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
// This function always operates on the source first, then the downstream. The
// downstream can assume the source to be ready.
[[nodiscard]] PageSourceIncMixIn& operator++() final {
TryLockGuard guard{this->single_threaded_};
// Increment the source.
if (sync_) {
++(*source_);
}
// Increment self.
++this->count_;
// Set at end.
this->at_end_ = this->count_ == this->n_batches_;
if (this->at_end_) {
this->cache_info_->Commit();
if (this->n_batches_ != 0) {
CHECK_EQ(this->count_, this->n_batches_);
// If this is the first round of iterations, we have just built the binary cache
// from source. For a non-sync page type, the source hasn't been updated to the end
// iteration yet due to skipped increment. We increment the source here and it will
// call the `EndIter` method itself.
bool src_need_inc = !sync_ && this->source_->Iter() != 0;
if (src_need_inc) {
CHECK_EQ(this->source_->Iter(), this->count_ - 1);
++(*source_);
}
this->EndIter();
if (src_need_inc) {
CHECK(this->cache_info_->written);
}
CHECK_GE(this->count_, 1);
} else {
this->Fetch();
}
this->Fetch();
if (sync_) {
// Sanity check.
CHECK_EQ(source_->Iter(), this->count_);
}
return *this;
}
void Reset() final {
if (sync_) {
this->source_->Reset();
}
this->source_->Reset();
Super::Reset();
}
};