[EM] Move prefetch in reset into the end of the iteration. (#10529)
This commit is contained in:
@@ -118,7 +118,8 @@ TEST(SparsePageDMatrix, RetainSparsePage) {
|
||||
// Test GHistIndexMatrix can avoid loading sparse page after the initialization.
|
||||
TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(6).GenerateSparsePageDMatrix(
|
||||
std::size_t n_batches = 6;
|
||||
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix(
|
||||
tmpdir.path + "/", true);
|
||||
Context ctx;
|
||||
bst_bin_t n_bins{256};
|
||||
@@ -171,12 +172,30 @@ TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) {
|
||||
// Restore the batch parameter by passing it in again through check_ghist
|
||||
check_ghist();
|
||||
}
|
||||
|
||||
// half the pages
|
||||
auto it = Xy->GetBatches<SparsePage>(&ctx).begin();
|
||||
for (std::int32_t i = 0; i < 3; ++i) {
|
||||
++it;
|
||||
{
|
||||
auto it = Xy->GetBatches<SparsePage>(&ctx).begin();
|
||||
for (std::size_t i = 0; i < n_batches / 2; ++i) {
|
||||
++it;
|
||||
}
|
||||
check_ghist();
|
||||
}
|
||||
{
|
||||
auto it = Xy->GetBatches<GHistIndexMatrix>(&ctx, batch_param).begin();
|
||||
for (std::size_t i = 0; i < n_batches / 2; ++i) {
|
||||
++it;
|
||||
}
|
||||
check_ghist();
|
||||
}
|
||||
{
|
||||
BatchParam regen{n_bins, common::Span{hess.data(), hess.size()}, true};
|
||||
auto it = Xy->GetBatches<GHistIndexMatrix>(&ctx, regen).begin();
|
||||
for (std::size_t i = 0; i < n_batches / 2; ++i) {
|
||||
++it;
|
||||
}
|
||||
check_ghist();
|
||||
}
|
||||
check_ghist();
|
||||
}
|
||||
|
||||
TEST(SparsePageDMatrix, MetaInfo) {
|
||||
|
||||
@@ -41,31 +41,77 @@ TEST(SparsePageDMatrix, EllpackPage) {
|
||||
TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
|
||||
// Test Ellpack can avoid loading sparse page after the initialization.
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(6).GenerateSparsePageDMatrix(
|
||||
std::size_t n_batches = 6;
|
||||
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix(
|
||||
tmpdir.path + "/", true);
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto cpu = ctx.MakeCPU();
|
||||
bst_bin_t n_bins{256};
|
||||
double sparse_thresh{0.8};
|
||||
BatchParam batch_param{n_bins, sparse_thresh};
|
||||
|
||||
std::int32_t k = 0;
|
||||
for (auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||
auto impl = page.Impl();
|
||||
ASSERT_EQ(page.Size(), 30);
|
||||
ASSERT_EQ(k, impl->base_rowid);
|
||||
k += page.Size();
|
||||
}
|
||||
auto check_ellpack = [&]() {
|
||||
std::int32_t k = 0;
|
||||
for (auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||
auto impl = page.Impl();
|
||||
ASSERT_EQ(page.Size(), 30);
|
||||
ASSERT_EQ(k, impl->base_rowid);
|
||||
k += page.Size();
|
||||
}
|
||||
};
|
||||
|
||||
auto casted = std::dynamic_pointer_cast<data::SparsePageDMatrix>(Xy);
|
||||
CHECK(casted);
|
||||
check_ellpack();
|
||||
|
||||
// Make the number of fetches don't change (no new fetch)
|
||||
auto n_fetches = casted->SparsePageFetchCount();
|
||||
for (std::int32_t i = 0; i < 3; ++i) {
|
||||
for (std::size_t i = 0; i < 3; ++i) {
|
||||
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||
}
|
||||
auto casted = std::dynamic_pointer_cast<data::SparsePageDMatrix>(Xy);
|
||||
ASSERT_EQ(casted->SparsePageFetchCount(), n_fetches);
|
||||
}
|
||||
check_ellpack();
|
||||
|
||||
dh::device_vector<float> hess(Xy->Info().num_row_, 1.0f);
|
||||
for (std::size_t i = 0; i < 4; ++i) {
|
||||
for ([[maybe_unused]] auto const& page : Xy->GetBatches<SparsePage>(&ctx)) {
|
||||
}
|
||||
for ([[maybe_unused]] auto const& page : Xy->GetBatches<SortedCSCPage>(&cpu)) {
|
||||
}
|
||||
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||
}
|
||||
// Approx tree method pages
|
||||
{
|
||||
BatchParam regen{n_bins, dh::ToSpan(hess), false};
|
||||
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, regen)) {
|
||||
}
|
||||
}
|
||||
{
|
||||
BatchParam regen{n_bins, dh::ToSpan(hess), true};
|
||||
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, regen)) {
|
||||
}
|
||||
}
|
||||
|
||||
check_ellpack();
|
||||
}
|
||||
|
||||
// half the pages
|
||||
{
|
||||
auto it = Xy->GetBatches<SparsePage>(&ctx).begin();
|
||||
for (std::size_t i = 0; i < n_batches / 2; ++i) {
|
||||
++it;
|
||||
}
|
||||
check_ellpack();
|
||||
}
|
||||
{
|
||||
auto it = Xy->GetBatches<EllpackPage>(&ctx, batch_param).begin();
|
||||
for (std::size_t i = 0; i < n_batches / 2; ++i) {
|
||||
++it;
|
||||
}
|
||||
check_ellpack();
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
||||
@@ -115,12 +161,7 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
|
||||
|
||||
for (size_t i = 0; i < iterators.size(); ++i) {
|
||||
ASSERT_EQ((*iterators[i]).Impl()->gidx_buffer.HostVector(), gidx_buffers.at(i).HostVector());
|
||||
if (i != iterators.size() - 1) {
|
||||
ASSERT_EQ(iterators[i].use_count(), 1);
|
||||
} else {
|
||||
// The last batch is still being held by sparse page DMatrix.
|
||||
ASSERT_EQ(iterators[i].use_count(), 2);
|
||||
}
|
||||
ASSERT_EQ(iterators[i].use_count(), 1);
|
||||
}
|
||||
|
||||
// make sure it's const and the caller can not modify the content of page.
|
||||
|
||||
Reference in New Issue
Block a user