[EM] Move prefetch in reset into the end of the iteration. (#10529)

This commit is contained in:
Jiaming Yuan
2024-07-03 03:48:18 +08:00
committed by GitHub
parent e537b0969f
commit 9cb4c938da
4 changed files with 133 additions and 48 deletions

View File

@@ -118,7 +118,8 @@ TEST(SparsePageDMatrix, RetainSparsePage) {
// Test GHistIndexMatrix can avoid loading sparse page after the initialization.
TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) {
dmlc::TemporaryDirectory tmpdir;
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(6).GenerateSparsePageDMatrix(
std::size_t n_batches = 6;
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix(
tmpdir.path + "/", true);
Context ctx;
bst_bin_t n_bins{256};
@@ -171,12 +172,30 @@ TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) {
// Restore the batch parameter by passing it in again through check_ghist
check_ghist();
}
// Advance an iterator through only half of the pages, then re-run the check.
auto it = Xy->GetBatches<SparsePage>(&ctx).begin();
for (std::int32_t i = 0; i < 3; ++i) {
++it;
{
auto it = Xy->GetBatches<SparsePage>(&ctx).begin();
for (std::size_t i = 0; i < n_batches / 2; ++i) {
++it;
}
check_ghist();
}
{
auto it = Xy->GetBatches<GHistIndexMatrix>(&ctx, batch_param).begin();
for (std::size_t i = 0; i < n_batches / 2; ++i) {
++it;
}
check_ghist();
}
{
BatchParam regen{n_bins, common::Span{hess.data(), hess.size()}, true};
auto it = Xy->GetBatches<GHistIndexMatrix>(&ctx, regen).begin();
for (std::size_t i = 0; i < n_batches / 2; ++i) {
++it;
}
check_ghist();
}
check_ghist();
}
TEST(SparsePageDMatrix, MetaInfo) {

View File

@@ -41,31 +41,77 @@ TEST(SparsePageDMatrix, EllpackPage) {
TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
// Test Ellpack can avoid loading sparse page after the initialization.
dmlc::TemporaryDirectory tmpdir;
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(6).GenerateSparsePageDMatrix(
std::size_t n_batches = 6;
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix(
tmpdir.path + "/", true);
auto ctx = MakeCUDACtx(0);
auto cpu = ctx.MakeCPU();
bst_bin_t n_bins{256};
double sparse_thresh{0.8};
BatchParam batch_param{n_bins, sparse_thresh};
std::int32_t k = 0;
for (auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
auto impl = page.Impl();
ASSERT_EQ(page.Size(), 30);
ASSERT_EQ(k, impl->base_rowid);
k += page.Size();
}
auto check_ellpack = [&]() {
std::int32_t k = 0;
for (auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
auto impl = page.Impl();
ASSERT_EQ(page.Size(), 30);
ASSERT_EQ(k, impl->base_rowid);
k += page.Size();
}
};
auto casted = std::dynamic_pointer_cast<data::SparsePageDMatrix>(Xy);
CHECK(casted);
check_ellpack();
// Make sure the number of fetches doesn't change (no new fetch)
auto n_fetches = casted->SparsePageFetchCount();
for (std::int32_t i = 0; i < 3; ++i) {
for (std::size_t i = 0; i < 3; ++i) {
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
}
auto casted = std::dynamic_pointer_cast<data::SparsePageDMatrix>(Xy);
ASSERT_EQ(casted->SparsePageFetchCount(), n_fetches);
}
check_ellpack();
dh::device_vector<float> hess(Xy->Info().num_row_, 1.0f);
for (std::size_t i = 0; i < 4; ++i) {
for ([[maybe_unused]] auto const& page : Xy->GetBatches<SparsePage>(&ctx)) {
}
for ([[maybe_unused]] auto const& page : Xy->GetBatches<SortedCSCPage>(&cpu)) {
}
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
}
// Approx tree method pages
{
BatchParam regen{n_bins, dh::ToSpan(hess), false};
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, regen)) {
}
}
{
BatchParam regen{n_bins, dh::ToSpan(hess), true};
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, regen)) {
}
}
check_ellpack();
}
// Advance an iterator through only half of the pages, then check a full pass over the Ellpack pages.
{
auto it = Xy->GetBatches<SparsePage>(&ctx).begin();
for (std::size_t i = 0; i < n_batches / 2; ++i) {
++it;
}
check_ellpack();
}
{
auto it = Xy->GetBatches<EllpackPage>(&ctx, batch_param).begin();
for (std::size_t i = 0; i < n_batches / 2; ++i) {
++it;
}
check_ellpack();
}
}
TEST(SparsePageDMatrix, MultipleEllpackPages) {
@@ -115,12 +161,7 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
for (size_t i = 0; i < iterators.size(); ++i) {
ASSERT_EQ((*iterators[i]).Impl()->gidx_buffer.HostVector(), gidx_buffers.at(i).HostVector());
if (i != iterators.size() - 1) {
ASSERT_EQ(iterators[i].use_count(), 1);
} else {
// The last batch is still being held by sparse page DMatrix.
ASSERT_EQ(iterators[i].use_count(), 2);
}
ASSERT_EQ(iterators[i].use_count(), 1);
}
// make sure it's const and the caller can not modify the content of page.