diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 87d3be1fe..0821ce648 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -239,42 +239,52 @@ struct Entry { }; /** - * \brief Parameters for constructing histogram index batches. + * @brief Parameters for constructing histogram index batches. */ struct BatchParam { /** - * \brief Maximum number of bins per feature for histograms. + * @brief Maximum number of bins per feature for histograms. */ bst_bin_t max_bin{0}; /** - * \brief Hessian, used for sketching with future approx implementation. + * @brief Hessian, used for sketching with future approx implementation. */ common::Span hess; /** - * \brief Whether should we force DMatrix to regenerate the batch. Only used for + * @brief Whether should we force DMatrix to regenerate the batch. Only used for * GHistIndex. */ bool regen{false}; /** - * \brief Forbid regenerating the gradient index. Used for internal validation. + * @brief Forbid regenerating the gradient index. Used for internal validation. */ bool forbid_regen{false}; /** - * \brief Parameter used to generate column matrix for hist. + * @brief Parameter used to generate column matrix for hist. */ double sparse_thresh{std::numeric_limits::quiet_NaN()}; + /** + * @brief Used for GPU external memory. Whether to copy the data into device. + * + * This affects only the current round of iteration. + */ + bool prefetch_copy{true}; + /** + * @brief The number of batches to pre-fetch for external memory. + */ + std::int32_t n_prefetch_batches{3}; /** - * \brief Exact or others that don't need histogram. + * @brief Exact or others that don't need histogram. */ BatchParam() = default; /** - * \brief Used by the hist tree method. + * @brief Used by the hist tree method. */ BatchParam(bst_bin_t max_bin, double sparse_thresh) : max_bin{max_bin}, sparse_thresh{sparse_thresh} {} /** - * \brief Used by the approx tree method. + * @brief Used by the approx tree method. * * Get batch with sketch weighted by hessian. The batch will be regenerated if the * span is changed, so caller should keep the span for each iteration. @@ -295,7 +305,7 @@ struct BatchParam { } [[nodiscard]] bool Initialized() const { return max_bin != 0; } /** - * \brief Make a copy of self for DMatrix to describe how its existing index was generated. + * @brief Make a copy of self for DMatrix to describe how its existing index was generated. */ [[nodiscard]] BatchParam MakeCache() const { auto p = *this; diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index e17d72000..4f39497b8 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -60,7 +60,7 @@ template RET_IF_NOT(fi->Read(&impl->is_dense)); RET_IF_NOT(fi->Read(&impl->row_stride)); - if (has_hmm_ats_) { + if (has_hmm_ats_ && !this->param_.prefetch_copy) { RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer)); } else { RET_IF_NOT(ReadDeviceVec(fi, &impl->gidx_buffer)); @@ -95,7 +95,7 @@ template CHECK(this->cuts_->cut_values_.DeviceCanRead()); impl->SetCuts(this->cuts_); - fi->Read(page); + fi->Read(page, this->param_.prefetch_copy); dh::DefaultStream().Sync(); return true; diff --git a/src/data/ellpack_page_raw_format.h b/src/data/ellpack_page_raw_format.h index e2761c73f..9be2c50cf 100644 --- a/src/data/ellpack_page_raw_format.h +++ b/src/data/ellpack_page_raw_format.h @@ -26,13 +26,17 @@ class EllpackHostCacheStream; class EllpackPageRawFormat : public SparsePageFormat { std::shared_ptr cuts_; DeviceOrd device_; + BatchParam param_; // Supports CUDA HMM or ATS bool has_hmm_ats_{false}; public: explicit EllpackPageRawFormat(std::shared_ptr cuts, DeviceOrd device, - bool has_hmm_ats) - : cuts_{std::move(cuts)}, device_{device}, has_hmm_ats_{has_hmm_ats} {} + BatchParam param, bool has_hmm_ats) + : cuts_{std::move(cuts)}, + device_{device}, + param_{std::move(param)}, + has_hmm_ats_{has_hmm_ats} {} [[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override; [[nodiscard]] std::size_t Write(const EllpackPage& page, common::AlignedFileWriteStream* fo) override; diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 4c49dbc9a..980fa154b 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -11,7 +11,6 @@ #include "../common/common.h" // for safe_cuda #include "../common/ref_resource_view.cuh" -#include "../common/cuda_pinned_allocator.h" // for pinned_allocator #include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream #include "../common/resource.cuh" // for PrivateCudaMmapConstStream #include "ellpack_page.cuh" // for EllpackPageImpl @@ -19,7 +18,6 @@ #include "ellpack_page_source.h" #include "proxy_dmatrix.cuh" // for Dispatch #include "xgboost/base.h" // for bst_idx_t -#include "../common/cuda_rt_utils.h" // for NvtxScopedRange #include "../common/transform_iterator.h" // for MakeIndexTransformIter namespace xgboost::data { @@ -91,14 +89,20 @@ class EllpackHostCacheStreamImpl { ptr_ += 1; } - void Read(EllpackPage* out) const { + void Read(EllpackPage* out, bool prefetch_copy) const { auto page = this->cache_->Get(ptr_); auto impl = out->Impl(); - impl->gidx_buffer = - common::MakeFixedVecWithCudaMalloc(page->gidx_buffer.size()); - dh::safe_cuda(cudaMemcpyAsync(impl->gidx_buffer.data(), page->gidx_buffer.data(), - page->gidx_buffer.size_bytes(), cudaMemcpyDefault)); + if (prefetch_copy) { + impl->gidx_buffer = + common::MakeFixedVecWithCudaMalloc(page->gidx_buffer.size()); + dh::safe_cuda(cudaMemcpyAsync(impl->gidx_buffer.data(), page->gidx_buffer.data(), + page->gidx_buffer.size_bytes(), cudaMemcpyDefault)); + } else { + auto res = page->gidx_buffer.Resource(); + impl->gidx_buffer = common::RefResourceView{ + res->DataAs(), page->gidx_buffer.size(), res}; + } impl->n_rows = page->Size(); impl->is_dense = page->IsDense(); @@ -120,7 +124,9 @@ std::shared_ptr EllpackHostCacheStream::Share() { return p_imp void EllpackHostCacheStream::Seek(bst_idx_t offset_bytes) { this->p_impl_->Seek(offset_bytes); } -void EllpackHostCacheStream::Read(EllpackPage* page) const { this->p_impl_->Read(page); } +void EllpackHostCacheStream::Read(EllpackPage* page, bool prefetch_copy) const { + this->p_impl_->Read(page, prefetch_copy); +} void EllpackHostCacheStream::Write(EllpackPage const& page) { this->p_impl_->Write(page); } @@ -162,8 +168,9 @@ EllpackCacheStreamPolicy::CreateWriter(StringV template std::unique_ptr< typename EllpackCacheStreamPolicy::ReaderT> -EllpackCacheStreamPolicy::CreateReader( - StringView name, std::uint64_t offset, std::uint64_t length) const; +EllpackCacheStreamPolicy::CreateReader(StringView name, + bst_idx_t offset, + bst_idx_t length) const; /** * EllpackMmapStreamPolicy @@ -233,6 +240,7 @@ void ExtEllpackPageSourceImpl::Fetch() { ++(*this->source_); CHECK_GE(this->source_->Iter(), 1); cuda_impl::Dispatch(proxy_, [this](auto const& value) { + CHECK(this->proxy_->Ctx()->IsCUDA()) << "All batches must use the same device type."; proxy_->Info().feature_types.SetDevice(dh::GetDevice(this->ctx_)); auto d_feature_types = proxy_->Info().feature_types.ConstDeviceSpan(); auto n_samples = value.NumRows(); diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h index 61f94a262..3c121b13c 100644 --- a/src/data/ellpack_page_source.h +++ b/src/data/ellpack_page_source.h @@ -53,7 +53,7 @@ class EllpackHostCacheStream { void Seek(bst_idx_t offset_bytes); - void Read(EllpackPage* page) const; + void Read(EllpackPage* page, bool prefetch_copy) const; void Write(EllpackPage const& page); }; @@ -71,9 +71,9 @@ class EllpackFormatPolicy { // For testing with the HMM flag. explicit EllpackFormatPolicy(bool has_hmm) : has_hmm_{has_hmm} {} - [[nodiscard]] auto CreatePageFormat() const { + [[nodiscard]] auto CreatePageFormat(BatchParam const& param) const { CHECK_EQ(cuts_->cut_values_.Device(), device_); - std::unique_ptr fmt{new EllpackPageRawFormat{cuts_, device_, has_hmm_}}; + std::unique_ptr fmt{new EllpackPageRawFormat{cuts_, device_, param, has_hmm_}}; return fmt; } diff --git a/src/data/extmem_quantile_dmatrix.cc b/src/data/extmem_quantile_dmatrix.cc index 96e88a55a..0bdab8f02 100644 --- a/src/data/extmem_quantile_dmatrix.cc +++ b/src/data/extmem_quantile_dmatrix.cc @@ -66,6 +66,8 @@ void ExtMemQuantileDMatrix::InitFromCPU( Context const *ctx, std::shared_ptr> iter, DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr ref) { + xgboost_NVTX_FN_RANGE(); + auto proxy = MakeProxy(proxy_handle); CHECK(proxy); @@ -118,7 +120,7 @@ BatchSet ExtMemQuantileDMatrix::GetGradientIndex(Context const } CHECK(this->ghist_index_source_); - this->ghist_index_source_->Reset(); + this->ghist_index_source_->Reset(param); if (!std::isnan(param.sparse_thresh) && param.sparse_thresh != tree::TrainParam::DftSparseThreshold()) { diff --git a/src/data/extmem_quantile_dmatrix.cu b/src/data/extmem_quantile_dmatrix.cu index f7f033e95..ea3f12c2e 100644 --- a/src/data/extmem_quantile_dmatrix.cu +++ b/src/data/extmem_quantile_dmatrix.cu @@ -11,6 +11,7 @@ #include "proxy_dmatrix.h" // for DataIterProxy #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for BatchParam +#include "../common/cuda_rt_utils.h" namespace xgboost::data { void ExtMemQuantileDMatrix::InitFromCUDA( @@ -78,9 +79,9 @@ BatchSet ExtMemQuantileDMatrix::GetEllpackBatches(Context const *, } std::visit( - [this](auto &&ptr) { + [this, param](auto &&ptr) { CHECK(ptr); - ptr->Reset(); + ptr->Reset(param); }, this->ellpack_page_source_); diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc index 88b83433f..d46f044ae 100644 --- a/src/data/gradient_index_page_source.cc +++ b/src/data/gradient_index_page_source.cc @@ -37,6 +37,7 @@ void ExtGradientIndexPageSource::Fetch() { CHECK_GE(source_->Iter(), 1); CHECK_NE(cuts_.Values().size(), 0); HostAdapterDispatch(proxy_, [this](auto const& value) { + CHECK(this->proxy_->Ctx()->IsCPU()) << "All batches must use the same device type."; // This does three things: // - Generate CSR matrix for gradient index. // - Generate the column matrix for gradient index. diff --git a/src/data/gradient_index_page_source.h b/src/data/gradient_index_page_source.h index 535e86670..e090e0744 100644 --- a/src/data/gradient_index_page_source.h +++ b/src/data/gradient_index_page_source.h @@ -31,7 +31,7 @@ class GHistIndexFormatPolicy { using FormatT = SparsePageFormat; public: - [[nodiscard]] auto CreatePageFormat() const { + [[nodiscard]] auto CreatePageFormat(BatchParam const&) const { std::unique_ptr fmt{new GHistIndexRawFormat{cuts_}}; return fmt; } diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index eb9da871b..7cabfbd14 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -82,7 +82,7 @@ void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) { // release the iterator and data. if (cache_info_.at(id)->written) { CHECK(sparse_page_source_); - sparse_page_source_->Reset(); + sparse_page_source_->Reset({}); return; } @@ -114,7 +114,7 @@ BatchSet SparsePageDMatrix::GetColumnBatches(Context const *ctx) { std::make_shared(this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id), sparse_page_source_); } else { - column_source_->Reset(); + column_source_->Reset({}); } return BatchSet{BatchIterator{this->column_source_}}; } @@ -129,7 +129,7 @@ BatchSet SparsePageDMatrix::GetSortedColumnBatches(Context const this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id), sparse_page_source_); } else { - sorted_column_source_->Reset(); + sorted_column_source_->Reset({}); } return BatchSet{BatchIterator{this->sorted_column_source_}}; } @@ -161,7 +161,7 @@ BatchSet SparsePageDMatrix::GetGradientIndex(Context const *ct param, std::move(cuts), this->IsDense(), ft, sparse_page_source_)); } else { CHECK(ghist_index_source_); - ghist_index_source_->Reset(); + ghist_index_source_->Reset(param); } return BatchSet{BatchIterator{this->ghist_index_source_}}; } diff --git a/src/data/sparse_page_dmatrix.cu b/src/data/sparse_page_dmatrix.cu index bcf7d4360..898069533 100644 --- a/src/data/sparse_page_dmatrix.cu +++ b/src/data/sparse_page_dmatrix.cu @@ -61,7 +61,7 @@ BatchSet SparsePageDMatrix::GetEllpackBatches(Context const* ctx, ellpack_page_source_); } else { CHECK(sparse_page_source_); - std::visit([&](auto&& ptr) { ptr->Reset(); }, this->ellpack_page_source_); + std::visit([&](auto&& ptr) { ptr->Reset(param); }, this->ellpack_page_source_); } auto batch_set = diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 5ec93a311..e014dea0c 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -204,7 +204,7 @@ class DefaultFormatPolicy { using FormatT = SparsePageFormat; public: - auto CreatePageFormat() const { + auto CreatePageFormat(BatchParam const&) const { std::unique_ptr fmt{::xgboost::data::CreatePageFormat("raw")}; return fmt; } @@ -245,6 +245,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol std::uint32_t count_{0}; // Total number of batches. bst_idx_t n_batches_{0}; + // How we pre-fetch the data. + BatchParam param_; std::shared_ptr cache_info_; @@ -267,12 +269,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol } // An heuristic for number of pre-fetched batches. We can make it part of BatchParam // to let user adjust number of pre-fetched batches when needed. - std::int32_t constexpr kPrefetches = 3; - std::int32_t n_prefetches = std::min(nthreads_, kPrefetches); + std::int32_t n_prefetches = std::min(nthreads_, this->param_.n_prefetch_batches); n_prefetches = std::max(n_prefetches, 1); std::int32_t n_prefetch_batches = std::min(static_cast(n_prefetches), n_batches_); - CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_; - CHECK_LE(n_prefetch_batches, kPrefetches); + CHECK_GT(n_prefetch_batches, 0); + CHECK_LE(n_prefetch_batches, this->param_.n_prefetch_batches); std::size_t fetch_it = count_; exce_.Rethrow(); @@ -287,7 +288,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, this] { auto page = std::make_shared(); this->exce_.Run([&] { - std::unique_ptr fmt{this->CreatePageFormat()}; + std::unique_ptr fmt{ + this->CreatePageFormat(this->param_)}; auto name = self->cache_info_->ShardName(); auto [offset, length] = self->cache_info_->View(fetch_it); std::unique_ptr fi{ @@ -317,7 +319,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol CHECK(!cache_info_->written); common::Timer timer; timer.Start(); - auto fmt{this->CreatePageFormat()}; + auto fmt{this->CreatePageFormat(this->param_)}; auto name = cache_info_->ShardName(); std::unique_ptr fo{ @@ -382,13 +384,16 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol this->count_ = 0; } - virtual void Reset() { + virtual void Reset(BatchParam const& param) { TryLockGuard guard{single_threaded_}; this->at_end_ = false; auto cnt = this->count_; this->count_ = 0; - if (cnt != 0) { + bool changed = this->param_.n_prefetch_batches != param.n_prefetch_batches; + this->param_ = param; + + if (cnt != 0 || changed) { // The last iteration did not get to the end, clear the ring to start from 0. this->ring_ = std::make_unique(); this->Fetch(); @@ -468,12 +473,12 @@ class SparsePageSource : public SparsePageSourceImpl { return *this; } - void Reset() override { + void Reset(BatchParam const& param) override { if (proxy_) { TryLockGuard guard{single_threaded_}; iter_.Reset(); } - SparsePageSourceImpl::Reset(); + SparsePageSourceImpl::Reset(param); TryLockGuard guard{single_threaded_}; base_row_id_ = 0; @@ -535,9 +540,9 @@ class PageSourceIncMixIn : public SparsePageSourceImpl { return *this; } - void Reset() final { - this->source_->Reset(); - Super::Reset(); + void Reset(BatchParam const& param) final { + this->source_->Reset(param); + Super::Reset(param); } }; @@ -626,11 +631,11 @@ class ExtQantileSourceMixin : public SparsePageSourceImpl return *this; } - void Reset() final { + void Reset(BatchParam const& param) final { if (this->source_) { this->source_->Reset(); } - Super::Reset(); + Super::Reset(param); } }; } // namespace xgboost::data diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index f4224a30e..31f93d18a 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -119,8 +119,11 @@ struct DeviceSplitCandidate { }; namespace cuda_impl { -inline BatchParam HistBatch(TrainParam const& param) { - return {param.max_bin, TrainParam::DftSparseThreshold()}; +inline BatchParam HistBatch(TrainParam const& param, bool prefetch_copy = true) { + auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()}; + p.prefetch_copy = prefetch_copy; + p.n_prefetch_batches = 1; + return p; } inline BatchParam HistBatch(bst_bin_t max_bin) { diff --git a/tests/cpp/data/test_ellpack_page_raw_format.cu b/tests/cpp/data/test_ellpack_page_raw_format.cu index 05aec905a..4ac4f9c70 100644 --- a/tests/cpp/data/test_ellpack_page_raw_format.cu +++ b/tests/cpp/data/test_ellpack_page_raw_format.cu @@ -13,77 +13,84 @@ namespace xgboost::data { namespace { -template -void TestEllpackPageRawFormat(FormatStreamPolicy *p_policy) { - auto &policy = *p_policy; - Context ctx{MakeCUDACtx(0)}; - auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()}; +class TestEllpackPageRawFormat : public ::testing::TestWithParam { + public: + template + void Run(FormatStreamPolicy *p_policy, bool prefetch_copy) { + auto &policy = *p_policy; + auto ctx = MakeCUDACtx(0); + auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()}; + param.prefetch_copy = prefetch_copy; - auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix(); - dmlc::TemporaryDirectory tmpdir; - std::string path = tmpdir.path + "/ellpack.page"; + auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix(); + dmlc::TemporaryDirectory tmpdir; + std::string path = tmpdir.path + "/ellpack.page"; - std::shared_ptr cuts; - for (auto const &page : m->GetBatches(&ctx, param)) { - cuts = page.Impl()->CutsShared(); - } + std::shared_ptr cuts; + for (auto const &page : m->GetBatches(&ctx, param)) { + cuts = page.Impl()->CutsShared(); + } - ASSERT_EQ(cuts->cut_values_.Device(), ctx.Device()); - ASSERT_TRUE(cuts->cut_values_.DeviceCanRead()); - policy.SetCuts(cuts, ctx.Device()); + ASSERT_EQ(cuts->cut_values_.Device(), ctx.Device()); + ASSERT_TRUE(cuts->cut_values_.DeviceCanRead()); + policy.SetCuts(cuts, ctx.Device()); - std::unique_ptr format{policy.CreatePageFormat()}; + std::unique_ptr format{policy.CreatePageFormat(param)}; + + std::size_t n_bytes{0}; + { + auto fo = policy.CreateWriter(StringView{path}, 0); + for (auto const &ellpack : m->GetBatches(&ctx, param)) { + n_bytes += format->Write(ellpack, fo.get()); + } + } + + EllpackPage page; + auto fi = policy.CreateReader(StringView{path}, static_cast(0), n_bytes); + ASSERT_TRUE(format->Read(&page, fi.get())); - std::size_t n_bytes{0}; - { - auto fo = policy.CreateWriter(StringView{path}, 0); for (auto const &ellpack : m->GetBatches(&ctx, param)) { - n_bytes += format->Write(ellpack, fo.get()); + auto loaded = page.Impl(); + auto orig = ellpack.Impl(); + ASSERT_EQ(loaded->Cuts().Ptrs(), orig->Cuts().Ptrs()); + ASSERT_EQ(loaded->Cuts().MinValues(), orig->Cuts().MinValues()); + ASSERT_EQ(loaded->Cuts().Values(), orig->Cuts().Values()); + ASSERT_EQ(loaded->base_rowid, orig->base_rowid); + ASSERT_EQ(loaded->row_stride, orig->row_stride); + std::vector h_loaded, h_orig; + [[maybe_unused]] auto h_loaded_acc = loaded->GetHostAccessor(&ctx, &h_loaded); + [[maybe_unused]] auto h_orig_acc = orig->GetHostAccessor(&ctx, &h_orig); + ASSERT_EQ(h_loaded, h_orig); } } - - EllpackPage page; - auto fi = policy.CreateReader(StringView{path}, static_cast(0), n_bytes); - ASSERT_TRUE(format->Read(&page, fi.get())); - - for (auto const &ellpack : m->GetBatches(&ctx, param)) { - auto loaded = page.Impl(); - auto orig = ellpack.Impl(); - ASSERT_EQ(loaded->Cuts().Ptrs(), orig->Cuts().Ptrs()); - ASSERT_EQ(loaded->Cuts().MinValues(), orig->Cuts().MinValues()); - ASSERT_EQ(loaded->Cuts().Values(), orig->Cuts().Values()); - ASSERT_EQ(loaded->base_rowid, orig->base_rowid); - ASSERT_EQ(loaded->row_stride, orig->row_stride); - std::vector h_loaded, h_orig; - [[maybe_unused]] auto h_loaded_acc = loaded->GetHostAccessor(&ctx, &h_loaded); - [[maybe_unused]] auto h_orig_acc = orig->GetHostAccessor(&ctx, &h_orig); - ASSERT_EQ(h_loaded, h_orig); - } -} +}; } // anonymous namespace -TEST(EllpackPageRawFormat, DiskIO) { +TEST_P(TestEllpackPageRawFormat, DiskIO) { EllpackMmapStreamPolicy policy{false}; - TestEllpackPageRawFormat(&policy); + this->Run(&policy, this->GetParam()); } -TEST(EllpackPageRawFormat, DiskIOHmm) { +TEST_P(TestEllpackPageRawFormat, DiskIOHmm) { if (common::SupportsPageableMem()) { EllpackMmapStreamPolicy policy{true}; - TestEllpackPageRawFormat(&policy); + this->Run(&policy, this->GetParam()); } else { GTEST_SKIP_("HMM is not supported."); } } -TEST(EllpackPageRawFormat, HostIO) { +TEST_P(TestEllpackPageRawFormat, HostIO) { { EllpackCacheStreamPolicy policy; - TestEllpackPageRawFormat(&policy); + this->Run(&policy, this->GetParam()); } { auto ctx = MakeCUDACtx(0); auto param = BatchParam{32, tree::TrainParam::DftSparseThreshold()}; + param.n_prefetch_batches = 1; + param.prefetch_copy = this->GetParam(); + EllpackCacheStreamPolicy policy; std::unique_ptr format{}; Cache cache{false, "name", "ellpack", true}; @@ -92,7 +99,7 @@ TEST(EllpackPageRawFormat, HostIO) { for (auto const &page : p_fmat->GetBatches(&ctx, param)) { if (!format) { policy.SetCuts(page.Impl()->CutsShared(), ctx.Device()); - format = policy.CreatePageFormat(); + format = policy.CreatePageFormat(param); } auto writer = policy.CreateWriter({}, i); auto n_bytes = format->Write(page, writer.get()); @@ -123,4 +130,6 @@ TEST(EllpackPageRawFormat, HostIO) { } } } + +INSTANTIATE_TEST_SUITE_P(EllpackPageRawFormat, TestEllpackPageRawFormat, ::testing::Bool()); } // namespace xgboost::data diff --git a/tests/cpp/data/test_sparse_page_raw_format.cc b/tests/cpp/data/test_sparse_page_raw_format.cc index bd0f97dcc..9f08c202f 100644 --- a/tests/cpp/data/test_sparse_page_raw_format.cc +++ b/tests/cpp/data/test_sparse_page_raw_format.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021-2023, XGBoost contributors + * Copyright 2021-2024, XGBoost contributors */ #include #include // for CSCPage, SortedCSCPage, SparsePage @@ -11,8 +11,6 @@ #include "../../../src/data/sparse_page_writer.h" // for CreatePageFormat #include "../helpers.h" // for RandomDataGenerator #include "dmlc/filesystem.h" // for TemporaryDirectory -#include "dmlc/io.h" // for Stream -#include "gtest/gtest_pred_impl.h" // for Test, AssertionResult, ASSERT_EQ, TEST #include "xgboost/context.h" // for Context namespace xgboost::data {