[EM] Move prefetch in reset into the end of the iteration. (#10529)
This commit is contained in:
parent
e537b0969f
commit
9cb4c938da
@ -9,6 +9,9 @@ void GradientIndexPageSource::Fetch() {
|
|||||||
if (count_ != 0 && !sync_) {
|
if (count_ != 0 && !sync_) {
|
||||||
// source is initialized to be the 0th page during construction, so when count_ is 0
|
// source is initialized to be the 0th page during construction, so when count_ is 0
|
||||||
// there's no need to increment the source.
|
// there's no need to increment the source.
|
||||||
|
//
|
||||||
|
// The mixin doesn't sync the source if `sync_` is false, we need to sync it
|
||||||
|
// ourselves.
|
||||||
++(*source_);
|
++(*source_);
|
||||||
}
|
}
|
||||||
// This is not read from cache so we still need it to be synced with sparse page source.
|
// This is not read from cache so we still need it to be synced with sparse page source.
|
||||||
|
|||||||
@ -42,7 +42,7 @@ struct Cache {
|
|||||||
std::string name;
|
std::string name;
|
||||||
std::string format;
|
std::string format;
|
||||||
// offset into binary cache file.
|
// offset into binary cache file.
|
||||||
std::vector<std::uint64_t> offset;
|
std::vector<bst_idx_t> offset;
|
||||||
|
|
||||||
Cache(bool w, std::string n, std::string fmt, bool on_host)
|
Cache(bool w, std::string n, std::string fmt, bool on_host)
|
||||||
: written{w}, on_host{on_host}, name{std::move(n)}, format{std::move(fmt)} {
|
: written{w}, on_host{on_host}, name{std::move(n)}, format{std::move(fmt)} {
|
||||||
@ -61,7 +61,7 @@ struct Cache {
|
|||||||
/**
|
/**
|
||||||
* @brief Record a page with size of n_bytes.
|
* @brief Record a page with size of n_bytes.
|
||||||
*/
|
*/
|
||||||
void Push(std::size_t n_bytes) { offset.push_back(n_bytes); }
|
void Push(bst_idx_t n_bytes) { offset.push_back(n_bytes); }
|
||||||
/**
|
/**
|
||||||
* @brief Returns the view start and length for the i^th page.
|
* @brief Returns the view start and length for the i^th page.
|
||||||
*/
|
*/
|
||||||
@ -73,7 +73,7 @@ struct Cache {
|
|||||||
/**
|
/**
|
||||||
* @brief Get the number of bytes for the i^th page.
|
* @brief Get the number of bytes for the i^th page.
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] std::uint64_t Bytes(std::size_t i) const { return offset.at(i + 1) - offset[i]; }
|
[[nodiscard]] bst_idx_t Bytes(std::size_t i) const { return offset.at(i + 1) - offset[i]; }
|
||||||
/**
|
/**
|
||||||
* @brief Call this once the write for the cache is complete.
|
* @brief Call this once the write for the cache is complete.
|
||||||
*/
|
*/
|
||||||
@ -218,7 +218,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
|
|||||||
common::Monitor monitor_;
|
common::Monitor monitor_;
|
||||||
|
|
||||||
[[nodiscard]] bool ReadCache() {
|
[[nodiscard]] bool ReadCache() {
|
||||||
CHECK(!at_end_);
|
|
||||||
if (!cache_info_->written) {
|
if (!cache_info_->written) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -259,11 +258,13 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
|
|||||||
return page;
|
return page;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
|
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
|
||||||
n_prefetch_batches)
|
n_prefetch_batches)
|
||||||
<< "Sparse DMatrix assumes forward iteration.";
|
<< "Sparse DMatrix assumes forward iteration.";
|
||||||
|
|
||||||
monitor_.Start("Wait");
|
monitor_.Start("Wait");
|
||||||
|
CHECK((*ring_)[count_].valid());
|
||||||
page_ = (*ring_)[count_].get();
|
page_ = (*ring_)[count_].get();
|
||||||
CHECK(!(*ring_)[count_].valid());
|
CHECK(!(*ring_)[count_].valid());
|
||||||
monitor_.Stop("Wait");
|
monitor_.Stop("Wait");
|
||||||
@ -331,12 +332,28 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
|
|||||||
return at_end_;
|
return at_end_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Call this at the last iteration.
|
||||||
|
void EndIter() {
|
||||||
|
CHECK_EQ(this->cache_info_->offset.size(), this->n_batches_ + 1);
|
||||||
|
this->cache_info_->Commit();
|
||||||
|
if (this->n_batches_ != 0) {
|
||||||
|
CHECK_EQ(this->count_, this->n_batches_);
|
||||||
|
}
|
||||||
|
CHECK_GE(this->count_, 1);
|
||||||
|
this->count_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
virtual void Reset() {
|
virtual void Reset() {
|
||||||
TryLockGuard guard{single_threaded_};
|
TryLockGuard guard{single_threaded_};
|
||||||
at_end_ = false;
|
|
||||||
count_ = 0;
|
this->at_end_ = false;
|
||||||
// Pre-fetch for the next round of iterations.
|
auto cnt = this->count_;
|
||||||
this->Fetch();
|
this->count_ = 0;
|
||||||
|
if (cnt != 0) {
|
||||||
|
// The last iteration did not get to the end, clear the ring to start from 0.
|
||||||
|
this->ring_ = std::make_unique<Ring>();
|
||||||
|
this->Fetch();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -404,16 +421,11 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
|
|||||||
CHECK_LE(count_, n_batches_);
|
CHECK_LE(count_, n_batches_);
|
||||||
|
|
||||||
if (at_end_) {
|
if (at_end_) {
|
||||||
CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1);
|
this->EndIter();
|
||||||
cache_info_->Commit();
|
this->proxy_ = nullptr;
|
||||||
if (n_batches_ != 0) {
|
|
||||||
CHECK_EQ(count_, n_batches_);
|
|
||||||
}
|
|
||||||
CHECK_GE(count_, 1);
|
|
||||||
proxy_ = nullptr;
|
|
||||||
} else {
|
|
||||||
this->Fetch();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this->Fetch();
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -446,36 +458,46 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
|
|||||||
PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
||||||
bst_idx_t n_batches, std::shared_ptr<Cache> cache, bool sync)
|
bst_idx_t n_batches, std::shared_ptr<Cache> cache, bool sync)
|
||||||
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
|
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
|
||||||
|
// This function always operates on the source first, then the downstream. The downstream
|
||||||
|
// can assume the source to be ready.
|
||||||
[[nodiscard]] PageSourceIncMixIn& operator++() final {
|
[[nodiscard]] PageSourceIncMixIn& operator++() final {
|
||||||
TryLockGuard guard{this->single_threaded_};
|
TryLockGuard guard{this->single_threaded_};
|
||||||
|
// Increment the source.
|
||||||
if (sync_) {
|
if (sync_) {
|
||||||
++(*source_);
|
++(*source_);
|
||||||
}
|
}
|
||||||
|
// Increment self.
|
||||||
++this->count_;
|
++this->count_;
|
||||||
|
// Set at end.
|
||||||
this->at_end_ = this->count_ == this->n_batches_;
|
this->at_end_ = this->count_ == this->n_batches_;
|
||||||
|
|
||||||
if (this->at_end_) {
|
if (this->at_end_) {
|
||||||
this->cache_info_->Commit();
|
// If this is the first round of iterations, we have just built the binary cache
|
||||||
if (this->n_batches_ != 0) {
|
// from source. For a non-sync page type, the source hasn't been updated to the end
|
||||||
CHECK_EQ(this->count_, this->n_batches_);
|
// iteration yet due to skipped increment. We increment the source here and it will
|
||||||
|
// call the `EndIter` method itself.
|
||||||
|
bool src_need_inc = !sync_ && this->source_->Iter() != 0;
|
||||||
|
if (src_need_inc) {
|
||||||
|
CHECK_EQ(this->source_->Iter(), this->count_ - 1);
|
||||||
|
++(*source_);
|
||||||
|
}
|
||||||
|
this->EndIter();
|
||||||
|
|
||||||
|
if (src_need_inc) {
|
||||||
|
CHECK(this->cache_info_->written);
|
||||||
}
|
}
|
||||||
CHECK_GE(this->count_, 1);
|
|
||||||
} else {
|
|
||||||
this->Fetch();
|
|
||||||
}
|
}
|
||||||
|
this->Fetch();
|
||||||
|
|
||||||
if (sync_) {
|
if (sync_) {
|
||||||
|
// Sanity check.
|
||||||
CHECK_EQ(source_->Iter(), this->count_);
|
CHECK_EQ(source_->Iter(), this->count_);
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Reset() final {
|
void Reset() final {
|
||||||
if (sync_) {
|
this->source_->Reset();
|
||||||
this->source_->Reset();
|
|
||||||
}
|
|
||||||
Super::Reset();
|
Super::Reset();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -118,7 +118,8 @@ TEST(SparsePageDMatrix, RetainSparsePage) {
|
|||||||
// Test GHistIndexMatrix can avoid loading sparse page after the initialization.
|
// Test GHistIndexMatrix can avoid loading sparse page after the initialization.
|
||||||
TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) {
|
TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) {
|
||||||
dmlc::TemporaryDirectory tmpdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(6).GenerateSparsePageDMatrix(
|
std::size_t n_batches = 6;
|
||||||
|
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix(
|
||||||
tmpdir.path + "/", true);
|
tmpdir.path + "/", true);
|
||||||
Context ctx;
|
Context ctx;
|
||||||
bst_bin_t n_bins{256};
|
bst_bin_t n_bins{256};
|
||||||
@ -171,12 +172,30 @@ TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) {
|
|||||||
// Restore the batch parameter by passing it in again through check_ghist
|
// Restore the batch parameter by passing it in again through check_ghist
|
||||||
check_ghist();
|
check_ghist();
|
||||||
}
|
}
|
||||||
|
|
||||||
// half the pages
|
// half the pages
|
||||||
auto it = Xy->GetBatches<SparsePage>(&ctx).begin();
|
{
|
||||||
for (std::int32_t i = 0; i < 3; ++i) {
|
auto it = Xy->GetBatches<SparsePage>(&ctx).begin();
|
||||||
++it;
|
for (std::size_t i = 0; i < n_batches / 2; ++i) {
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
check_ghist();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto it = Xy->GetBatches<GHistIndexMatrix>(&ctx, batch_param).begin();
|
||||||
|
for (std::size_t i = 0; i < n_batches / 2; ++i) {
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
check_ghist();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
BatchParam regen{n_bins, common::Span{hess.data(), hess.size()}, true};
|
||||||
|
auto it = Xy->GetBatches<GHistIndexMatrix>(&ctx, regen).begin();
|
||||||
|
for (std::size_t i = 0; i < n_batches / 2; ++i) {
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
check_ghist();
|
||||||
}
|
}
|
||||||
check_ghist();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, MetaInfo) {
|
TEST(SparsePageDMatrix, MetaInfo) {
|
||||||
|
|||||||
@ -41,31 +41,77 @@ TEST(SparsePageDMatrix, EllpackPage) {
|
|||||||
TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
|
TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
|
||||||
// Test Ellpack can avoid loading sparse page after the initialization.
|
// Test Ellpack can avoid loading sparse page after the initialization.
|
||||||
dmlc::TemporaryDirectory tmpdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(6).GenerateSparsePageDMatrix(
|
std::size_t n_batches = 6;
|
||||||
|
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix(
|
||||||
tmpdir.path + "/", true);
|
tmpdir.path + "/", true);
|
||||||
auto ctx = MakeCUDACtx(0);
|
auto ctx = MakeCUDACtx(0);
|
||||||
|
auto cpu = ctx.MakeCPU();
|
||||||
bst_bin_t n_bins{256};
|
bst_bin_t n_bins{256};
|
||||||
double sparse_thresh{0.8};
|
double sparse_thresh{0.8};
|
||||||
BatchParam batch_param{n_bins, sparse_thresh};
|
BatchParam batch_param{n_bins, sparse_thresh};
|
||||||
|
|
||||||
std::int32_t k = 0;
|
auto check_ellpack = [&]() {
|
||||||
for (auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
std::int32_t k = 0;
|
||||||
auto impl = page.Impl();
|
for (auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||||
ASSERT_EQ(page.Size(), 30);
|
auto impl = page.Impl();
|
||||||
ASSERT_EQ(k, impl->base_rowid);
|
ASSERT_EQ(page.Size(), 30);
|
||||||
k += page.Size();
|
ASSERT_EQ(k, impl->base_rowid);
|
||||||
}
|
k += page.Size();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
auto casted = std::dynamic_pointer_cast<data::SparsePageDMatrix>(Xy);
|
auto casted = std::dynamic_pointer_cast<data::SparsePageDMatrix>(Xy);
|
||||||
CHECK(casted);
|
CHECK(casted);
|
||||||
|
check_ellpack();
|
||||||
|
|
||||||
// Make sure the number of fetches doesn't change (no new fetch)
|
// Make sure the number of fetches doesn't change (no new fetch)
|
||||||
auto n_fetches = casted->SparsePageFetchCount();
|
auto n_fetches = casted->SparsePageFetchCount();
|
||||||
for (std::int32_t i = 0; i < 3; ++i) {
|
for (std::size_t i = 0; i < 3; ++i) {
|
||||||
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||||
}
|
}
|
||||||
auto casted = std::dynamic_pointer_cast<data::SparsePageDMatrix>(Xy);
|
auto casted = std::dynamic_pointer_cast<data::SparsePageDMatrix>(Xy);
|
||||||
ASSERT_EQ(casted->SparsePageFetchCount(), n_fetches);
|
ASSERT_EQ(casted->SparsePageFetchCount(), n_fetches);
|
||||||
}
|
}
|
||||||
|
check_ellpack();
|
||||||
|
|
||||||
|
dh::device_vector<float> hess(Xy->Info().num_row_, 1.0f);
|
||||||
|
for (std::size_t i = 0; i < 4; ++i) {
|
||||||
|
for ([[maybe_unused]] auto const& page : Xy->GetBatches<SparsePage>(&ctx)) {
|
||||||
|
}
|
||||||
|
for ([[maybe_unused]] auto const& page : Xy->GetBatches<SortedCSCPage>(&cpu)) {
|
||||||
|
}
|
||||||
|
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||||
|
}
|
||||||
|
// Approx tree method pages
|
||||||
|
{
|
||||||
|
BatchParam regen{n_bins, dh::ToSpan(hess), false};
|
||||||
|
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, regen)) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
BatchParam regen{n_bins, dh::ToSpan(hess), true};
|
||||||
|
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, regen)) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
check_ellpack();
|
||||||
|
}
|
||||||
|
|
||||||
|
// half the pages
|
||||||
|
{
|
||||||
|
auto it = Xy->GetBatches<SparsePage>(&ctx).begin();
|
||||||
|
for (std::size_t i = 0; i < n_batches / 2; ++i) {
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
check_ellpack();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto it = Xy->GetBatches<EllpackPage>(&ctx, batch_param).begin();
|
||||||
|
for (std::size_t i = 0; i < n_batches / 2; ++i) {
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
check_ellpack();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
||||||
@ -115,12 +161,7 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
|
|||||||
|
|
||||||
for (size_t i = 0; i < iterators.size(); ++i) {
|
for (size_t i = 0; i < iterators.size(); ++i) {
|
||||||
ASSERT_EQ((*iterators[i]).Impl()->gidx_buffer.HostVector(), gidx_buffers.at(i).HostVector());
|
ASSERT_EQ((*iterators[i]).Impl()->gidx_buffer.HostVector(), gidx_buffers.at(i).HostVector());
|
||||||
if (i != iterators.size() - 1) {
|
ASSERT_EQ(iterators[i].use_count(), 1);
|
||||||
ASSERT_EQ(iterators[i].use_count(), 1);
|
|
||||||
} else {
|
|
||||||
// The last batch is still being held by sparse page DMatrix.
|
|
||||||
ASSERT_EQ(iterators[i].use_count(), 2);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// make sure it's const and the caller can not modify the content of page.
|
// make sure it's const and the caller can not modify the content of page.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user