From f2f7dd87b860ed1d4fcdfb1ba1339a327c6c39d5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 11 Jan 2021 18:04:55 +0800 Subject: [PATCH] Use view for `SparsePage` exclusively. (#6590) --- include/xgboost/data.h | 9 ----- src/common/hist_util.cc | 3 +- src/data/data.cc | 9 ++--- src/data/simple_dmatrix.cc | 5 ++- src/gbm/gblinear.cc | 3 +- src/gbm/gbtree.cc | 5 +-- src/linear/coordinate_common.h | 19 ++++++---- src/linear/updater_gpu_coordinate.cu | 5 +-- src/linear/updater_shotgun.cc | 3 +- src/predictor/cpu_predictor.cc | 10 +++--- src/tree/updater_basemaker-inl.h | 11 +++--- src/tree/updater_colmaker.cc | 9 +++-- src/tree/updater_histmaker.cc | 12 ++++--- src/tree/updater_refresh.cc | 5 +-- tests/cpp/c_api/test_c_api.cc | 12 ++++--- tests/cpp/common/test_hist_util.h | 3 +- tests/cpp/data/test_adapter.cc | 5 +-- tests/cpp/data/test_data.cc | 27 ++++++++------- tests/cpp/data/test_simple_dmatrix.cc | 30 +++++++++------- tests/cpp/data/test_simple_dmatrix.cu | 33 +++++++++++------- tests/cpp/data/test_sparse_page_dmatrix.cc | 40 +++++++++++++--------- tests/cpp/predictor/test_cpu_predictor.cc | 3 +- tests/cpp/tree/test_quantile_hist.cc | 3 +- 23 files changed, 151 insertions(+), 113 deletions(-) diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 7226c7b82..7a16b77e2 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -252,15 +252,6 @@ class SparsePage { /*! \brief an instance of sparse vector in the batch */ using Inst = common::Span; - /*! \brief get i-th row from the batch */ - inline Inst operator[](size_t i) const { - const auto& data_vec = data.HostVector(); - const auto& offset_vec = offset.HostVector(); - size_t size = offset_vec[i + 1] - offset_vec[i]; - return {data_vec.data() + offset_vec[i], - static_cast(size)}; - } - HostSparsePageView GetView() const { return {offset.ConstHostSpan(), data.ConstHostSpan()}; } diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 09a09897a..c0031c047 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -78,6 +78,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) { const size_t batch_threads = std::max( size_t(1), std::min(batch.Size(), static_cast(omp_get_max_threads()))); + auto page = batch.GetView(); MemStackAllocator partial_sums(batch_threads); size_t* p_part = partial_sums.Get(); @@ -92,7 +93,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) { size_t sum = 0; for (size_t i = ibegin; i < iend; ++i) { - sum += batch[i].size(); + sum += page[i].size(); row_ptr[rbegin + 1 + i] = sum; } } diff --git a/src/data/data.cc b/src/data/data.cc index c98a4f5c2..4874e1791 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -825,19 +825,20 @@ SparsePage SparsePage::GetTranspose(int num_columns) const { const int nthread = omp_get_max_threads(); builder.InitBudget(num_columns, nthread); long batch_size = static_cast(this->Size()); // NOLINT(*) -#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static) + auto page = this->GetView(); +#pragma omp parallel for default(none) shared(batch_size, builder, page) schedule(static) for (long i = 0; i < batch_size; ++i) { // NOLINT(*) int tid = omp_get_thread_num(); - auto inst = (*this)[i]; + auto inst = page[i]; for (const auto& entry : inst) { builder.AddBudget(entry.index, tid); } } builder.InitStorage(); -#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static) +#pragma omp parallel for default(none) shared(batch_size, builder, page) schedule(static) for 
(long i = 0; i < batch_size; ++i) { // NOLINT(*) int tid = omp_get_thread_num(); - auto inst = (*this)[i]; + auto inst = page[i]; for (const auto& entry : inst) { builder.Push( entry.index, diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index c39909489..a4b4b583f 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -28,13 +28,12 @@ DMatrix* SimpleDMatrix::Slice(common::Span ridxs) { auto out = new SimpleDMatrix; SparsePage& out_page = out->sparse_page_; for (auto const &page : this->GetBatches()) { - page.data.HostVector(); - page.offset.HostVector(); + auto batch = page.GetView(); auto& h_data = out_page.data.HostVector(); auto& h_offset = out_page.offset.HostVector(); size_t rptr{0}; for (auto ridx : ridxs) { - auto inst = page[ridx]; + auto inst = batch[ridx]; rptr += inst.size(); std::copy(inst.begin(), inst.end(), std::back_inserter(h_data)); h_offset.emplace_back(rptr); diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index 2e94e8626..cc2a4a439 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -173,9 +173,10 @@ class GBLinear : public GradientBooster { for (const auto &batch : p_fmat->GetBatches()) { // parallel over local batch const auto nsize = static_cast(batch.Size()); + auto page = batch.GetView(); #pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < nsize; ++i) { - auto inst = batch[i]; + auto inst = page[i]; auto row_idx = static_cast(batch.base_rowid + i); // loop over output groups for (int gid = 0; gid < ngroup; ++gid) { diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 14fe617cd..a4fb6e28e 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -678,6 +678,7 @@ class Dart : public GBTree { CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group); // start collecting the prediction for (const auto &batch : p_fmat->GetBatches()) { + auto page = batch.GetView(); constexpr int kUnroll = 8; const auto nsize = static_cast(batch.Size()); const bst_omp_uint rest = nsize % kUnroll; @@ -692,7 +693,7 @@ class Dart : public GBTree { ridx[k] = static_cast(batch.base_rowid + i + k); } for (int k = 0; k < kUnroll; ++k) { - inst[k] = batch[i + k]; + inst[k] = page[i + k]; } for (int k = 0; k < kUnroll; ++k) { for (int gid = 0; gid < num_group; ++gid) { @@ -707,7 +708,7 @@ class Dart : public GBTree { for (bst_omp_uint i = nsize - rest; i < nsize; ++i) { RegTree::FVec& feats = thread_temp_[0]; const auto ridx = static_cast(batch.base_rowid + i); - const SparsePage::Inst inst = batch[i]; + const SparsePage::Inst inst = page[i]; for (int gid = 0; gid < num_group; ++gid) { const size_t offset = ridx * num_group + gid; preds[offset] += diff --git a/src/linear/coordinate_common.h b/src/linear/coordinate_common.h index 14f99bb3d..7974babbe 100644 --- a/src/linear/coordinate_common.h +++ b/src/linear/coordinate_common.h @@ -82,7 +82,8 @@ inline std::pair GetGradient(int group_idx, int num_group, int f DMatrix *p_fmat) { double sum_grad = 0.0, sum_hess = 0.0; for (const auto &batch : p_fmat->GetBatches()) { - auto col = batch[fidx]; + auto page = batch.GetView(); + auto col = page[fidx]; const auto ndata = static_cast(col.size()); for (bst_omp_uint j = 0; j < ndata; ++j) { const bst_float v = col[j].fvalue; @@ -111,7 +112,8 @@ inline std::pair GetGradientParallel(int group_idx, int num_grou DMatrix *p_fmat) { double sum_grad = 0.0, sum_hess = 0.0; for (const auto &batch : p_fmat->GetBatches()) { - auto col = batch[fidx]; + auto page = batch.GetView(); + auto col = page[fidx]; const auto ndata = 
static_cast(col.size()); #pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess) for (bst_omp_uint j = 0; j < ndata; ++j) { @@ -166,7 +168,8 @@ inline void UpdateResidualParallel(int fidx, int group_idx, int num_group, DMatrix *p_fmat) { if (dw == 0.0f) return; for (const auto &batch : p_fmat->GetBatches()) { - auto col = batch[fidx]; + auto page = batch.GetView(); + auto col = page[fidx]; // update grad value const auto num_row = static_cast(col.size()); #pragma omp parallel for schedule(static) @@ -334,9 +337,10 @@ class GreedyFeatureSelector : public FeatureSelector { // Calculate univariate gradient sums std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.)); for (const auto &batch : p_fmat->GetBatches()) { - #pragma omp parallel for schedule(static) + auto page = batch.GetView(); +#pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < nfeat; ++i) { - const auto col = batch[i]; + const auto col = page[i]; const bst_uint ndata = col.size(); auto &sums = gpair_sums_[group_idx * nfeat + i]; for (bst_uint j = 0u; j < ndata; ++j) { @@ -399,10 +403,11 @@ class ThriftyFeatureSelector : public FeatureSelector { // Calculate univariate gradient sums std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.)); for (const auto &batch : p_fmat->GetBatches()) { -// column-parallel is usually faster than row-parallel + auto page = batch.GetView(); + // column-parallel is usually fastaer than row-parallel #pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < nfeat; ++i) { - const auto col = batch[i]; + const auto col = page[i]; const bst_uint ndata = col.size(); for (bst_uint gid = 0u; gid < ngroup; ++gid) { auto &sums = gpair_sums_[gid * nfeat + i]; diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu index e7db2dc02..685ec85f9 100644 --- a/src/linear/updater_gpu_coordinate.cu +++ b/src/linear/updater_gpu_coordinate.cu @@ -60,6 +60,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT CHECK(p_fmat->SingleColBlock()); SparsePage const& batch = *(p_fmat->GetBatches().begin()); + auto page = batch.GetView(); if (IsEmpty()) { return; @@ -72,7 +73,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT row_ptr_ = {0}; // iterate through columns for (size_t fidx = 0; fidx < batch.Size(); fidx++) { - common::Span col = batch[fidx]; + common::Span col = page[fidx]; auto cmp = [](Entry e1, Entry e2) { return e1.index < e2.index; }; @@ -89,7 +90,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT data_.resize(row_ptr_.back()); gpair_.resize(num_row_ * model_param.num_output_group); for (size_t fidx = 0; fidx < batch.Size(); fidx++) { - auto col = batch[fidx]; + auto col = page[fidx]; auto seg = column_segments[fidx]; dh::safe_cuda(cudaMemcpy( data_.data().get() + row_ptr_[fidx], diff --git a/src/linear/updater_shotgun.cc b/src/linear/updater_shotgun.cc index 08604ce73..a9321ba9a 100644 --- a/src/linear/updater_shotgun.cc +++ b/src/linear/updater_shotgun.cc @@ -52,6 +52,7 @@ class ShotgunUpdater : public LinearUpdater { selector_->Setup(*model, in_gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm, param_.reg_lambda_denorm, 0); for (const auto &batch : p_fmat->GetBatches()) { + auto page = batch.GetView(); const auto nfeat = static_cast(batch.Size()); #pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < nfeat; ++i) { @@ -60,7 +61,7 @@ class ShotgunUpdater : public LinearUpdater { param_.reg_lambda_denorm); if (ii < 0) continue; 
const bst_uint fid = ii; - auto col = batch[ii]; + auto col = page[ii]; for (int gid = 0; gid < ngroup; ++gid) { double sum_grad = 0.0, sum_hess = 0.0; for (auto& c : col) { diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 3adcb63f9..80c9b0ee9 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -360,18 +360,19 @@ class CPUPredictor : public Predictor { // start collecting the prediction for (const auto &batch : p_fmat->GetBatches()) { // parallel over local batch + auto page = batch.GetView(); const auto nsize = static_cast(batch.Size()); #pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < nsize; ++i) { const int tid = omp_get_thread_num(); auto ridx = static_cast(batch.base_rowid + i); RegTree::FVec &feats = thread_temp_[tid]; - feats.Fill(batch[i]); + feats.Fill(page[i]); for (unsigned j = 0; j < ntree_limit; ++j) { int tid = model.trees[j]->GetLeafIndex(feats); preds[ridx * ntree_limit + j] = static_cast(tid); } - feats.Drop(batch[i]); + feats.Drop(page[i]); } } } @@ -407,6 +408,7 @@ class CPUPredictor : public Predictor { const std::vector& base_margin = info.base_margin_.HostVector(); // start collecting the contributions for (const auto &batch : p_fmat->GetBatches()) { + auto page = batch.GetView(); // parallel over local batch const auto nsize = static_cast(batch.Size()); #pragma omp parallel for schedule(static) @@ -417,7 +419,7 @@ class CPUPredictor : public Predictor { // loop over all classes for (int gid = 0; gid < ngroup; ++gid) { bst_float* p_contribs = &contribs[(row_idx * ngroup + gid) * ncolumns]; - feats.Fill(batch[i]); + feats.Fill(page[i]); // calculate contributions for (unsigned j = 0; j < ntree_limit; ++j) { std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); @@ -435,7 +437,7 @@ class CPUPredictor : public Predictor { (tree_weights == nullptr ? 
1 : (*tree_weights)[j]); } } - feats.Drop(batch[i]); + feats.Drop(page[i]); // add base margin to BIAS if (base_margin.size() != 0) { p_contribs[ncolumns - 1] += base_margin[row_idx * ngroup + gid]; diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index 66ab91982..b38ac2e7c 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -59,8 +59,9 @@ class BaseMaker: public TreeUpdater { -std::numeric_limits::max()); // start accumulating statistics for (const auto &batch : p_fmat->GetBatches()) { + auto page = batch.GetView(); for (bst_uint fid = 0; fid < batch.Size(); ++fid) { - auto c = batch[fid]; + auto c = page[fid]; if (c.size() != 0) { CHECK_LT(fid * 2, fminmax_.size()); fminmax_[fid * 2 + 0] = @@ -249,8 +250,9 @@ class BaseMaker: public TreeUpdater { inline void CorrectNonDefaultPositionByBatch( const SparsePage &batch, const std::vector &sorted_split_set, const RegTree &tree) { + auto page = batch.GetView(); for (size_t fid = 0; fid < batch.Size(); ++fid) { - auto col = batch[fid]; + auto col = page[fid]; auto it = std::lower_bound(sorted_split_set.begin(), sorted_split_set.end(), fid); if (it != sorted_split_set.end() && *it == fid) { @@ -308,10 +310,11 @@ class BaseMaker: public TreeUpdater { std::vector fsplits; this->GetSplitSet(nodes, tree, &fsplits); for (const auto &batch : p_fmat->GetBatches()) { + auto page = batch.GetView(); for (auto fid : fsplits) { - auto col = batch[fid]; + auto col = page[fid]; const auto ndata = static_cast(col.size()); - #pragma omp parallel for schedule(static) +#pragma omp parallel for schedule(static) for (bst_omp_uint j = 0; j < ndata; ++j) { const bst_uint ridx = col[j].index; const bst_float fvalue = col[j].fvalue; diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 70a1f4dab..1997ecaf0 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -77,8 +77,9 @@ class ColMaker: public TreeUpdater { if (column_densities_.empty()) { std::vector column_size(dmat->Info().num_col_); for (const auto &batch : dmat->GetBatches()) { + auto page = batch.GetView(); for (auto i = 0u; i < batch.Size(); i++) { - column_size[i] += batch[i].size(); + column_size[i] += page[i].size(); } } column_densities_.resize(column_size.size()); @@ -447,13 +448,14 @@ class ColMaker: public TreeUpdater { #endif // defined(_OPENMP) { dmlc::OMPException omp_handler; + auto page = batch.GetView(); #pragma omp parallel for schedule(dynamic, batch_size) for (bst_omp_uint i = 0; i < num_features; ++i) { omp_handler.Run([&]() { auto evaluator = tree_evaluator_.GetEvaluator(); bst_feature_t const fid = feat_set[i]; int32_t const tid = omp_get_thread_num(); - auto c = batch[fid]; + auto c = page[fid]; const bool ind = c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue; if (colmaker_train_param_.NeedForwardSearch( @@ -562,8 +564,9 @@ class ColMaker: public TreeUpdater { std::sort(fsplits.begin(), fsplits.end()); fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin()); for (const auto &batch : p_fmat->GetBatches()) { + auto page = batch.GetView(); for (auto fid : fsplits) { - auto col = batch[fid]; + auto col = page[fid]; const auto ndata = static_cast(col.size()); #pragma omp parallel for schedule(static) for (bst_omp_uint j = 0; j < ndata; ++j) { diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 7edceee72..fff642ba8 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -338,6 +338,7 @@ class 
CQHistMaker: public HistMaker { thread_hist_.resize(omp_get_max_threads()); // start accumulating statistics for (const auto &batch : p_fmat->GetBatches()) { + auto page = batch.GetView(); // start enumeration const auto nsize = static_cast(fset.size()); #pragma omp parallel for schedule(dynamic, 1) @@ -345,7 +346,7 @@ class CQHistMaker: public HistMaker { int fid = fset[i]; int offset = feat2workindex_[fid]; if (offset >= 0) { - this->UpdateHistCol(gpair, batch[fid], info, tree, + this->UpdateHistCol(gpair, page[fid], info, tree, fset, offset, &thread_hist_[omp_get_thread_num()]); } @@ -413,15 +414,15 @@ class CQHistMaker: public HistMaker { for (const auto &batch : p_fmat->GetBatches()) { // TWOPASS: use the real set + split set in the column iteration. this->CorrectNonDefaultPositionByBatch(batch, fsplit_set_, tree); - + auto page = batch.GetView(); // start enumeration const auto nsize = static_cast(work_set_.size()); - #pragma omp parallel for schedule(dynamic, 1) +#pragma omp parallel for schedule(dynamic, 1) for (bst_omp_uint i = 0; i < nsize; ++i) { int fid = work_set_[i]; int offset = feat2workindex_[fid]; if (offset >= 0) { - this->UpdateSketchCol(gpair, batch[fid], tree, + this->UpdateSketchCol(gpair, page[fid], tree, work_set_size, offset, &thread_sketch_[omp_get_thread_num()]); } @@ -696,6 +697,7 @@ class GlobalProposalHistMaker: public CQHistMaker { for (const auto &batch : p_fmat->GetBatches()) { // TWOPASS: use the real set + split set in the column iteration. this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree); + auto page = batch.GetView(); // start enumeration const auto nsize = static_cast(this->work_set_.size()); @@ -704,7 +706,7 @@ class GlobalProposalHistMaker: public CQHistMaker { int fid = this->work_set_[i]; int offset = this->feat2workindex_[fid]; if (offset >= 0) { - this->UpdateHistCol(gpair, batch[fid], info, tree, + this->UpdateHistCol(gpair, page[fid], info, tree, fset, offset, &this->thread_hist_[omp_get_thread_num()]); } diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index b51fb3a3c..20485c670 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -69,11 +69,12 @@ class TreeRefresher: public TreeUpdater { const MetaInfo &info = p_fmat->Info(); // start accumulating statistics for (const auto &batch : p_fmat->GetBatches()) { + auto page = batch.GetView(); CHECK_LT(batch.Size(), std::numeric_limits::max()); const auto nbatch = static_cast(batch.Size()); - #pragma omp parallel for schedule(static) +#pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < nbatch; ++i) { - SparsePage::Inst inst = batch[i]; + SparsePage::Inst inst = page[i]; const int tid = omp_get_thread_num(); const auto ridx = static_cast(batch.base_rowid + i); RegTree::FVec &feats = fvec_temp[tid]; diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index 0b5c94741..83dedd2da 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -30,10 +30,11 @@ TEST(CAPI, XGDMatrixCreateFromMatDT) { ASSERT_EQ(info.num_nonzero_, 6ul); for (const auto &batch : (*dmat)->GetBatches()) { - ASSERT_EQ(batch[0][0].fvalue, 0.0f); - ASSERT_EQ(batch[0][1].fvalue, -4.0f); - ASSERT_EQ(batch[2][0].fvalue, 3.0f); - ASSERT_EQ(batch[2][1].fvalue, 0.0f); + auto page = batch.GetView(); + ASSERT_EQ(page[0][0].fvalue, 0.0f); + ASSERT_EQ(page[0][1].fvalue, -4.0f); + ASSERT_EQ(page[2][0].fvalue, 3.0f); + ASSERT_EQ(page[2][1].fvalue, 0.0f); } delete dmat; @@ -62,8 +63,9 @@ TEST(CAPI, 
XGDMatrixCreateFromMatOmp) { ASSERT_EQ(info.num_nonzero_, num_cols * row - num_missing); for (const auto &batch : (*dmat)->GetBatches()) { + auto page = batch.GetView(); for (size_t i = 0; i < batch.Size(); i++) { - auto inst = batch[i]; + auto inst = page[i]; for (auto e : inst) { ASSERT_EQ(e.fvalue, 1.5); } diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 9b025ed9b..a8636c854 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -176,9 +176,10 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, // Collect data into columns std::vector> columns(dmat->Info().num_col_); for (auto& batch : dmat->GetBatches()) { + auto page = batch.GetView(); ASSERT_GT(batch.Size(), 0ul); for (auto i = 0ull; i < batch.Size(); i++) { - for (auto e : batch[i]) { + for (auto e : page[i]) { columns[e.index].push_back(e.fvalue); } } diff --git a/tests/cpp/data/test_adapter.cc b/tests/cpp/data/test_adapter.cc index bb8d8b627..fb1cd0249 100644 --- a/tests/cpp/data/test_adapter.cc +++ b/tests/cpp/data/test_adapter.cc @@ -47,7 +47,8 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) { EXPECT_EQ(dmat.Info().num_nonzero_, 8); auto &batch = *dmat.GetBatches().begin(); - auto inst = batch[0]; + auto page = batch.GetView(); + auto inst = page[0]; EXPECT_EQ(inst[0].fvalue, 1); EXPECT_EQ(inst[0].index, 0); EXPECT_EQ(inst[1].fvalue, 3); @@ -57,7 +58,7 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) { EXPECT_EQ(inst[3].fvalue, 7); EXPECT_EQ(inst[3].index, 3); - inst = batch[1]; + inst = page[1]; EXPECT_EQ(inst[0].fvalue, 2); EXPECT_EQ(inst[0].index, 0); EXPECT_EQ(inst[1].fvalue, 4); diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc index c63c4b1d7..195dd6965 100644 --- a/tests/cpp/data/test_data.cc +++ b/tests/cpp/data/test_data.cc @@ -11,9 +11,9 @@ namespace xgboost { TEST(SparsePage, PushCSC) { std::vector offset {0}; std::vector data; - SparsePage page; - page.offset.HostVector() = offset; - page.data.HostVector() = data; + SparsePage batch; + batch.offset.HostVector() = offset; + batch.data.HostVector() = data; offset = {0, 1, 4}; for (size_t i = 0; i < offset.back(); ++i) { @@ -24,25 +24,26 @@ TEST(SparsePage, PushCSC) { other.offset.HostVector() = offset; other.data.HostVector() = data; - page.PushCSC(other); + batch.PushCSC(other); - ASSERT_EQ(page.offset.HostVector().size(), offset.size()); - ASSERT_EQ(page.data.HostVector().size(), data.size()); + ASSERT_EQ(batch.offset.HostVector().size(), offset.size()); + ASSERT_EQ(batch.data.HostVector().size(), data.size()); for (size_t i = 0; i < offset.size(); ++i) { - ASSERT_EQ(page.offset.HostVector()[i], offset[i]); + ASSERT_EQ(batch.offset.HostVector()[i], offset[i]); } for (size_t i = 0; i < data.size(); ++i) { - ASSERT_EQ(page.data.HostVector()[i].index, data[i].index); + ASSERT_EQ(batch.data.HostVector()[i].index, data[i].index); } - page.PushCSC(other); - ASSERT_EQ(page.offset.HostVector().size(), offset.size()); - ASSERT_EQ(page.data.Size(), data.size() * 2); + batch.PushCSC(other); + ASSERT_EQ(batch.offset.HostVector().size(), offset.size()); + ASSERT_EQ(batch.data.Size(), data.size() * 2); for (size_t i = 0; i < offset.size(); ++i) { - ASSERT_EQ(page.offset.HostVector()[i], offset[i] * 2); + ASSERT_EQ(batch.offset.HostVector()[i], offset[i] * 2); } + auto page = batch.GetView(); auto inst = page[0]; ASSERT_EQ(inst.size(), 2ul); for (auto entry : inst) { @@ -78,7 +79,7 @@ TEST(SparsePage, PushCSCAfterTranspose) { // The feature value for a feature in each row 
should be identical, as that is // how the dmatrix has been created for (size_t i = 0; i < page.Size(); ++i) { - auto inst = page[i]; + auto inst = page.GetView()[i]; for (size_t j = 1; j < inst.size(); ++j) { ASSERT_EQ(inst[0].fvalue, inst[j].fvalue); } diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 8fdd0d09f..3147395a6 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -39,7 +39,8 @@ TEST(SimpleDMatrix, RowAccess) { EXPECT_EQ(row_count, dmat->Info().num_row_); // Test the data read into the first row auto &batch = *dmat->GetBatches().begin(); - auto first_row = batch[0]; + auto page = batch.GetView(); + auto first_row = page[0]; ASSERT_EQ(first_row.size(), 3); EXPECT_EQ(first_row[2].index, 2); EXPECT_EQ(first_row[2].fvalue, 20); @@ -143,8 +144,9 @@ TEST(SimpleDMatrix, FromDense) { EXPECT_EQ(dmat.Info().num_nonzero_, 6); for (auto &batch : dmat.GetBatches()) { + auto page = batch.GetView(); for (auto i = 0ull; i < batch.Size(); i++) { - auto inst = batch[i]; + auto inst = page[i]; for (auto j = 0ull; j < inst.size(); j++) { EXPECT_EQ(inst[j].fvalue, data[i * n + j]); EXPECT_EQ(inst[j].index, j); @@ -165,19 +167,20 @@ TEST(SimpleDMatrix, FromCSC) { EXPECT_EQ(dmat.Info().num_nonzero_, 5); auto &batch = *dmat.GetBatches().begin(); - auto inst = batch[0]; + auto page = batch.GetView(); + auto inst = page[0]; EXPECT_EQ(inst[0].fvalue, 1); EXPECT_EQ(inst[0].index, 0); EXPECT_EQ(inst[1].fvalue, 2); EXPECT_EQ(inst[1].index, 1); - inst = batch[1]; + inst = page[1]; EXPECT_EQ(inst[0].fvalue, 3); EXPECT_EQ(inst[0].index, 0); EXPECT_EQ(inst[1].fvalue, 4); EXPECT_EQ(inst[1].index, 1); - inst = batch[2]; + inst = page[2]; EXPECT_EQ(inst[0].fvalue, 5); EXPECT_EQ(inst[0].index, 1); } @@ -194,11 +197,12 @@ TEST(SimpleDMatrix, FromFile) { std::unique_ptr> parser( dmlc::Parser::Create(filename.c_str(), 0, 1, "auto")); - auto verify_batch = [kExpectedNumRow](SparsePage const &batch) { + auto verify_batch = [kExpectedNumRow](SparsePage const &page) { + auto batch = page.GetView(); EXPECT_EQ(batch.Size(), kExpectedNumRow); - EXPECT_EQ(batch.offset.HostVector(), + EXPECT_EQ(page.offset.HostVector(), std::vector({0, 3, 6, 9, 12, 15, 15})); - EXPECT_EQ(batch.base_rowid, 0); + EXPECT_EQ(page.base_rowid, 0); for (auto i = 0ull; i < batch.Size() - 1; i++) { if (i % 2 == 0) { @@ -251,8 +255,10 @@ TEST(SimpleDMatrix, Slice) { ASSERT_EQ(out->Info().labels_upper_bound_.Size(), ridxs.size()); ASSERT_EQ(out->Info().base_margin_.Size(), ridxs.size() * kClasses); - for (auto const& in_page : p_m->GetBatches()) { - for (auto const &out_page : out->GetBatches()) { + for (auto const& in_batch : p_m->GetBatches()) { + auto in_page = in_batch.GetView(); + for (auto const &out_batch : out->GetBatches()) { + auto out_page = out_batch.GetView(); for (size_t i = 0; i < ridxs.size(); ++i) { auto ridx = ridxs[i]; auto out_inst = out_page[i]; @@ -305,8 +311,8 @@ TEST(SimpleDMatrix, SaveLoadBinary) { auto row_iter = dmat->GetBatches().begin(); auto row_iter_read = dmat_read->GetBatches().begin(); // Test the data read into the first row - auto first_row = (*row_iter)[0]; - auto first_row_read = (*row_iter_read)[0]; + auto first_row = (*row_iter).GetView()[0]; + auto first_row_read = (*row_iter_read).GetView()[0]; EXPECT_EQ(first_row.size(), first_row_read.size()); EXPECT_EQ(first_row[2].index, first_row_read[2].index); EXPECT_EQ(first_row[2].fvalue, first_row_read[2].fvalue); diff --git a/tests/cpp/data/test_simple_dmatrix.cu 
b/tests/cpp/data/test_simple_dmatrix.cu index aff977bd2..d74f5b150 100644 --- a/tests/cpp/data/test_simple_dmatrix.cu +++ b/tests/cpp/data/test_simple_dmatrix.cu @@ -35,8 +35,9 @@ TEST(SimpleDMatrix, FromColumnarDenseBasic) { void TestDenseColumn(DMatrix* dmat, size_t n_rows, size_t n_cols) { for (auto& batch : dmat->GetBatches()) { + auto page = batch.GetView(); for (auto i = 0ull; i < batch.Size(); i++) { - auto inst = batch[i]; + auto inst = page[i]; for (auto j = 0ull; j < inst.size(); j++) { EXPECT_EQ(inst[j].fvalue, i * 2); EXPECT_EQ(inst[j].index, j); @@ -162,8 +163,9 @@ TEST(SimpleDMatrix, FromColumnarWithEmptyRows) { -1); for (auto& batch : dmat.GetBatches()) { + auto page = batch.GetView(); for (auto i = 0ull; i < batch.Size(); i++) { - auto inst = batch[i]; + auto inst = page[i]; for (auto j = 0ull; j < inst.size(); j++) { EXPECT_EQ(inst[j].fvalue, i); EXPECT_EQ(inst[j].index, j); @@ -257,8 +259,9 @@ TEST(SimpleCSRSource, FromColumnarSparse) { data::CudfAdapter adapter(str); data::SimpleDMatrix dmat(&adapter, 2.0, -1); for (auto& batch : dmat.GetBatches()) { + auto page = batch.GetView(); for (auto i = 0ull; i < batch.Size(); i++) { - auto inst = batch[i]; + auto inst = page[i]; for (auto e : inst) { ASSERT_NE(e.fvalue, 2.0); } @@ -304,8 +307,9 @@ TEST(SimpleDMatrix, FromColumnarSparseBasic) { EXPECT_EQ(dmat.Info().num_nonzero_, 32); for (auto& batch : dmat.GetBatches()) { + auto page = batch.GetView(); for (auto i = 0ull; i < batch.Size(); i++) { - auto inst = batch[i]; + auto inst = page[i]; for (auto j = 0ull; j < inst.size(); j++) { EXPECT_EQ(inst[j].fvalue, i * 2); EXPECT_EQ(inst[j].index, j); @@ -329,8 +333,9 @@ TEST(SimpleDMatrix, FromCupy){ EXPECT_EQ(dmat.Info().num_nonzero_, rows*cols); for (auto& batch : dmat.GetBatches()) { + auto page = batch.GetView(); for (auto i = 0ull; i < batch.Size(); i++) { - auto inst = batch[i]; + auto inst = page[i]; for (auto j = 0ull; j < inst.size(); j++) { EXPECT_EQ(inst[j].fvalue, i * cols + j); EXPECT_EQ(inst[j].index, j); @@ -354,12 +359,14 @@ TEST(SimpleDMatrix, FromCupySparse){ EXPECT_EQ(dmat.Info().num_row_, rows); EXPECT_EQ(dmat.Info().num_nonzero_, rows * cols - 2); auto& batch = *dmat.GetBatches().begin(); - auto inst0 = batch[0]; - auto inst1 = batch[1]; - EXPECT_EQ(batch[0].size(), 1); - EXPECT_EQ(batch[1].size(), 1); - EXPECT_EQ(batch[0][0].fvalue, 0.0f); - EXPECT_EQ(batch[0][0].index, 0); - EXPECT_EQ(batch[1][0].fvalue, 3.0f); - EXPECT_EQ(batch[1][0].index, 1); + auto page = batch.GetView(); + + auto inst0 = page[0]; + auto inst1 = page[1]; + EXPECT_EQ(page[0].size(), 1); + EXPECT_EQ(page[1].size(), 1); + EXPECT_EQ(page[0][0].fvalue, 0.0f); + EXPECT_EQ(page[0][0].index, 0); + EXPECT_EQ(page[1][0].fvalue, 3.0f); + EXPECT_EQ(page[1][0].index, 1); } diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index 3e6e46d7b..f20e259fd 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -39,7 +39,8 @@ TEST(SparsePageDMatrix, RowAccess) { // Test the data read into the first row auto &batch = *dmat->GetBatches().begin(); - auto first_row = batch[0]; + auto page = batch.GetView(); + auto first_row = page[0]; ASSERT_EQ(first_row.size(), 3ul); EXPECT_EQ(first_row[2].index, 2u); EXPECT_EQ(first_row[2].fvalue, 20); @@ -54,16 +55,18 @@ TEST(SparsePageDMatrix, ColAccess) { // Loop over the batches and assert the data is as expected for (auto const &col_batch : dmat->GetBatches()) { - EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_); - 
EXPECT_EQ(col_batch[1][0].fvalue, 10.0f); - EXPECT_EQ(col_batch[1].size(), 1); + auto col_page = col_batch.GetView(); + EXPECT_EQ(col_page.Size(), dmat->Info().num_col_); + EXPECT_EQ(col_page[1][0].fvalue, 10.0f); + EXPECT_EQ(col_page[1].size(), 1); } // Loop over the batches and assert the data is as expected for (auto const &col_batch : dmat->GetBatches()) { - EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_); - EXPECT_EQ(col_batch[1][0].fvalue, 10.0f); - EXPECT_EQ(col_batch[1].size(), 1); + auto col_page = col_batch.GetView(); + EXPECT_EQ(col_page.Size(), dmat->Info().num_col_); + EXPECT_EQ(col_page[1][0].fvalue, 10.0f); + EXPECT_EQ(col_page[1].size(), 1); } EXPECT_TRUE(FileExists(tmp_file + ".cache")); @@ -238,8 +241,9 @@ TEST(SparsePageDMatrix, FromDense) { EXPECT_EQ(dmat.Info().num_nonzero_, 6); for (auto &batch : dmat.GetBatches()) { + auto page = batch.GetView(); for (auto i = 0ull; i < batch.Size(); i++) { - auto inst = batch[i]; + auto inst = page[i]; for (auto j = 0ull; j < inst.size(); j++) { EXPECT_EQ(inst[j].fvalue, data[i * n + j]); EXPECT_EQ(inst[j].index, j); @@ -262,19 +266,20 @@ TEST(SparsePageDMatrix, FromCSC) { EXPECT_EQ(dmat.Info().num_nonzero_, 5); auto &batch = *dmat.GetBatches().begin(); - auto inst = batch[0]; + auto page = batch.GetView(); + auto inst = page[0]; EXPECT_EQ(inst[0].fvalue, 1); EXPECT_EQ(inst[0].index, 0); EXPECT_EQ(inst[1].fvalue, 2); EXPECT_EQ(inst[1].index, 1); - inst = batch[1]; + inst = page[1]; EXPECT_EQ(inst[0].fvalue, 3); EXPECT_EQ(inst[0].index, 0); EXPECT_EQ(inst[1].fvalue, 4); EXPECT_EQ(inst[1].index, 1); - inst = batch[2]; + inst = page[2]; EXPECT_EQ(inst[0].fvalue, 5); EXPECT_EQ(inst[0].index, 1); } @@ -294,19 +299,20 @@ TEST(SparsePageDMatrix, FromFile) { for (auto &batch : dmat.GetBatches()) { std::vector expected_offset(batch.Size() + 1); + auto page = batch.GetView(); int n = -3; std::generate(expected_offset.begin(), expected_offset.end(), [&n] { return n += 3; }); EXPECT_EQ(batch.offset.HostVector(), expected_offset); if (batch.base_rowid % 2 == 0) { - EXPECT_EQ(batch[0][0].index, 0); - EXPECT_EQ(batch[0][1].index, 1); - EXPECT_EQ(batch[0][2].index, 2); + EXPECT_EQ(page[0][0].index, 0); + EXPECT_EQ(page[0][1].index, 1); + EXPECT_EQ(page[0][2].index, 2); } else { - EXPECT_EQ(batch[0][0].index, 0); - EXPECT_EQ(batch[0][1].index, 3); - EXPECT_EQ(batch[0][2].index, 4); + EXPECT_EQ(page[0][0].index, 0); + EXPECT_EQ(page[0][1].index, 3); + EXPECT_EQ(page[0][2].index, 4); } } } diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 3d5eecb17..510ce073d 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -39,9 +39,10 @@ TEST(CpuPredictor, Basic) { // Test predict instance auto const &batch = *dmat->GetBatches().begin(); + auto page = batch.GetView(); for (size_t i = 0; i < batch.Size(); i++) { std::vector instance_out_predictions; - cpu_predictor->PredictInstance(batch[i], &instance_out_predictions, model); + cpu_predictor->PredictInstance(page[i], &instance_out_predictions, model); ASSERT_EQ(instance_out_predictions[0], 1.5); } diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index de24f23f5..9167e5747 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -72,12 +72,13 @@ class QuantileHistMock : public QuantileHistMaker { ASSERT_LT(*std::max_element(gmat.index.begin(), gmat.index.end()), gmat.cut.Ptrs().back()); for (const auto& batch : 
p_fmat->GetBatches()) { + auto page = batch.GetView(); for (size_t i = 0; i < batch.Size(); ++i) { const size_t rid = batch.base_rowid + i; ASSERT_LT(rid, num_row); const size_t gmat_row_offset = gmat.row_ptr[rid]; ASSERT_LT(gmat_row_offset, gmat.index.Size()); - SparsePage::Inst inst = batch[i]; + SparsePage::Inst inst = page[i]; ASSERT_EQ(gmat.row_ptr[rid] + inst.size(), gmat.row_ptr[rid + 1]); for (size_t j = 0; j < inst.size(); ++j) { // Each entry of GHistIndexMatrix represents a bin ID
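
Note for reviewers skimming the diff: the change applied at every call site is mechanical. SparsePage::operator[] is deleted from include/xgboost/data.h, and each loop instead materializes a host view once per batch via SparsePage::GetView() and indexes that view. The sketch below is illustration only and not part of the patch; it assumes the xgboost types touched above (DMatrix, SparsePage, HostSparsePageView), and the helper name CountNonMissing is hypothetical.

    // Minimal sketch of the call-site migration this patch performs.
    // Mirrors e.g. the hist_util.cc hunk: sum += page[i].size();
    #include <cstddef>
    #include <xgboost/data.h>

    // Hypothetical helper, for illustration only.
    std::size_t CountNonMissing(xgboost::DMatrix* p_fmat) {
      std::size_t nnz = 0;
      for (auto const& batch : p_fmat->GetBatches<xgboost::SparsePage>()) {
        // Before: auto inst = batch[i];   // SparsePage::operator[], removed by this patch
        auto page = batch.GetView();       // After: take the host view once per batch
        for (std::size_t i = 0; i < batch.Size(); ++i) {
          nnz += page[i].size();           // page[i] is a span of Entry for row i
        }
      }
      return nnz;
    }

Taking the view once outside the loop also means the data/offset HostVector() lookups that the old operator[] performed on every access now happen once per batch.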