diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc
index fbe1ee4dc..c14da59a7 100644
--- a/src/common/hist_util.cc
+++ b/src/common/hist_util.cc
@@ -15,7 +15,6 @@
 #include "random.h"
 #include "column_matrix.h"
 #include "quantile.h"
-#include "./../tree/updater_quantile_hist.h"
 #include "../data/gradient_index.h"
 
 #if defined(XGBOOST_MM_PREFETCH_PRESENT)
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index c2ff58593..e066ed3a3 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -92,18 +92,20 @@ class HistogramCuts {
   // Return the index of a cut point that is strictly greater than the input
   // value, or the last available index if none exists
-  BinIdx SearchBin(float value, uint32_t column_id) const {
-    auto beg = cut_ptrs_.ConstHostVector().at(column_id);
-    auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
-    const auto &values = cut_values_.ConstHostVector();
+  BinIdx SearchBin(float value, uint32_t column_id, std::vector<uint32_t> const& ptrs,
+                   std::vector<float> const& values) const {
+    auto end = ptrs[column_id + 1];
+    auto beg = ptrs[column_id];
     auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
     BinIdx idx = it - values.cbegin();
-    if (idx == end) {
-      idx -= 1;
-    }
+    idx -= !!(idx == end);
     return idx;
   }
 
+  BinIdx SearchBin(float value, uint32_t column_id) const {
+    return this->SearchBin(value, column_id, Ptrs(), Values());
+  }
+
   /**
    * \brief Search the bin index for numerical feature.
    */
@@ -129,7 +131,13 @@
   }
 };
 
-inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
+/**
+ * \brief Run CPU sketching on DMatrix.
+ *
+ * \param use_sorted Whether to use SortedCSC for sketching; it's more efficient but
+ *                   consumes more memory.
+ */
+inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, bool use_sorted = false,
                                      Span<float> const hessian = {}) {
   HistogramCuts out;
   auto const& info = m->Info();
@@ -146,13 +154,23 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
       reduced[i] += entries_per_column[i];
     }
   }
-  HostSketchContainer container(reduced, max_bins,
-                                m->Info().feature_types.ConstHostSpan(),
-                                HostSketchContainer::UseGroup(info), threads);
-  for (auto const &page : m->GetBatches<SparsePage>()) {
-    container.PushRowPage(page, info, hessian);
+
+  if (!use_sorted) {
+    HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
+                                  hessian, threads);
+    for (auto const& page : m->GetBatches<SparsePage>()) {
+      container.PushRowPage(page, info, hessian);
+    }
+    container.MakeCuts(&out);
+  } else {
+    SortedSketchContainer container{
+        max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, threads};
+    for (auto const& page : m->GetBatches<SortedCSCPage>()) {
+      container.PushColPage(page, info, hessian);
+    }
+    container.MakeCuts(&out);
   }
-  container.MakeCuts(&out);
+
   return out;
 }
diff --git a/src/common/quantile.cc b/src/common/quantile.cc
index 4cb1d3a5c..8780c7539 100644
--- a/src/common/quantile.cc
+++ b/src/common/quantile.cc
@@ -12,40 +12,34 @@
 namespace xgboost {
 namespace common {
 
-HostSketchContainer::HostSketchContainer(
-    std::vector<bst_row_t> columns_size, int32_t max_bins,
-    common::Span<FeatureType const> feature_types, bool use_group,
-    int32_t n_threads)
+template <typename WQSketch>
+SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> columns_size,
+                                                   int32_t max_bins,
+                                                   common::Span<FeatureType const> feature_types,
+                                                   bool use_group, int32_t n_threads)
     : feature_types_(feature_types.cbegin(), feature_types.cend()),
-      columns_size_{std::move(columns_size)}, max_bins_{max_bins},
-      use_group_ind_{use_group}, n_threads_{n_threads} {
+      columns_size_{std::move(columns_size)},
+      max_bins_{max_bins},
+      use_group_ind_{use_group},
+      n_threads_{n_threads} {
   monitor_.Init(__func__);
   CHECK_NE(columns_size_.size(), 0);
   sketches_.resize(columns_size_.size());
   CHECK_GE(n_threads_, 1);
   categories_.resize(columns_size_.size());
-  ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
-    auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
-    n_bins = std::max(n_bins, static_cast<size_t>(1));
-    auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
-    if (!IsCat(this->feature_types_, i)) {
-      sketches_[i].Init(columns_size_[i], eps);
-      sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2);
-    }
-  });
 }
 
-std::vector<size_t>
-HostSketchContainer::CalcColumnSize(SparsePage const &batch,
-                                    bst_feature_t const n_columns,
-                                    size_t const nthreads) {
+template <typename WQSketch>
+std::vector<size_t> SketchContainerImpl<WQSketch>::CalcColumnSize(SparsePage const &batch,
+                                                                  bst_feature_t const n_columns,
+                                                                  size_t const nthreads) {
   auto page = batch.GetView();
   std::vector<std::vector<size_t>> column_sizes(nthreads);
   for (auto &column : column_sizes) {
     column.resize(n_columns, 0);
   }
-  ParallelFor(omp_ulong(page.Size()), nthreads, [&](omp_ulong i) {
+  ParallelFor(page.Size(), nthreads, [&](omp_ulong i) {
     auto &local_column_sizes = column_sizes.at(omp_get_thread_num());
     auto row = page[i];
    auto const *p_row = row.data();
@@ -54,7 +48,7 @@ HostSketchContainer::CalcColumnSize(SparsePage const &batch,
     }
   });
   std::vector<size_t> entries_per_columns(n_columns, 0);
-  ParallelFor(bst_omp_uint(n_columns), nthreads, [&](bst_omp_uint i) {
+  ParallelFor(n_columns, nthreads, [&](bst_omp_uint i) {
     for (auto const &thread : column_sizes) {
       entries_per_columns[i] += thread[i];
     }
@@ -62,8 +56,10 @@ HostSketchContainer::CalcColumnSize(SparsePage const &batch,
   return entries_per_columns;
 }
 
-std::vector<bst_feature_t> HostSketchContainer::LoadBalance(
-    SparsePage const &batch, bst_feature_t n_columns, size_t const nthreads) {
+template <typename WQSketch>
+std::vector<bst_feature_t> SketchContainerImpl<WQSketch>::LoadBalance(SparsePage const &batch,
+                                                                      bst_feature_t n_columns,
+                                                                      size_t const nthreads) {
   /* Some sparse datasets have their mass concentrating on small number of features.  To
    * avoid waiting for a few threads running forever, we here distribute different number
    * of columns to different threads according to number of entries.
@@ -101,9 +97,8 @@ std::vector<bst_feature_t> HostSketchContainer::LoadBalance(
 
 namespace {
 // Function to merge hessian and sample weights
-std::vector<float> MergeWeights(MetaInfo const &info,
-                                Span<float const> const hessian,
-                                bool use_group, int32_t n_threads) {
+std::vector<float> MergeWeights(MetaInfo const &info, Span<float const> hessian, bool use_group,
+                                int32_t n_threads) {
   CHECK_EQ(hessian.size(), info.num_row_);
   std::vector<float> results(hessian.size());
   auto const &group_ptr = info.group_ptr_;
@@ -148,8 +143,9 @@ std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
 }
 }  // anonymous namespace
 
-void HostSketchContainer::PushRowPage(
-    SparsePage const &page, MetaInfo const &info, Span<float const> hessian) {
+template <typename WQSketch>
+void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo const &info,
+                                                Span<float const> hessian) {
   monitor_.Start(__func__);
   bst_feature_t n_columns = info.num_col_;
   auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
@@ -216,11 +212,12 @@ void HostSketchContainer::PushRowPage(
   monitor_.Stop(__func__);
 }
 
-void HostSketchContainer::GatherSketchInfo(
-    std::vector<WQSketch::SummaryContainer> const &reduced,
+template <typename WQSketch>
+void SketchContainerImpl<WQSketch>::GatherSketchInfo(
+    std::vector<typename WQSketch::SummaryContainer> const &reduced,
     std::vector<size_t> *p_worker_segments,
     std::vector<bst_row_t> *p_sketches_scan,
-    std::vector<WQSketch::Entry> *p_global_sketches) {
+    std::vector<typename WQSketch::Entry> *p_global_sketches) {
   auto& worker_segments = *p_worker_segments;
   worker_segments.resize(1, 0);
   auto world = rabit::GetWorldSize();
@@ -251,8 +248,8 @@
 
   auto total = worker_segments.back();
   auto& global_sketches = *p_global_sketches;
-  global_sketches.resize(total, WQSketch::Entry{0, 0, 0, 0});
-  auto worker_sketch = Span<WQSketch::Entry>{global_sketches}.subspan(
+  global_sketches.resize(total, typename WQSketch::Entry{0, 0, 0, 0});
+  auto worker_sketch = Span<typename WQSketch::Entry>{global_sketches}.subspan(
       worker_segments[rank], worker_segments[rank + 1] - worker_segments[rank]);
   size_t cursor = 0;
   for (auto const &sketch : reduced) {
@@ -261,14 +258,15 @@ void HostSketchContainer::GatherSketchInfo(
     cursor += sketch.size;
   }
 
-  static_assert(sizeof(WQSketch::Entry) / 4 == sizeof(float), "");
+  static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float), "");
   rabit::Allreduce<rabit::op::Sum>(
       reinterpret_cast<float *>(global_sketches.data()),
-      global_sketches.size() * sizeof(WQSketch::Entry) / sizeof(float));
+      global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
 }
 
-void HostSketchContainer::AllReduce(
-    std::vector<WQSketch::SummaryContainer> *p_reduced,
+template <typename WQSketch>
+void SketchContainerImpl<WQSketch>::AllReduce(
+    std::vector<typename WQSketch::SummaryContainer> *p_reduced,
     std::vector<int32_t>* p_num_cuts) {
   monitor_.Start(__func__);
   auto& num_cuts = *p_num_cuts;
@@ -291,7 +289,7 @@ void HostSketchContainer::AllReduce(
         std::min(global_column_size[i],
                  static_cast<size_t>(max_bins_ * WQSketch::kFactor)));
     if (global_column_size[i] != 0) {
-      WQSketch::SummaryContainer out;
+      typename WQSketch::SummaryContainer out;
       sketches_[i].GetSummary(&out);
       reduced[i].Reserve(intermediate_num_cuts);
       CHECK(reduced[i].data);
@@ -309,11 +307,11 @@ void HostSketchContainer::AllReduce(
 
   std::vector<size_t> worker_segments(1, 0);  // CSC pointer to sketches.
   std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);
-  std::vector<WQSketch::Entry> global_sketches;
+  std::vector<typename WQSketch::Entry> global_sketches;
   this->GatherSketchInfo(reduced, &worker_segments, &sketches_scan,
                          &global_sketches);
 
-  std::vector<WQSketch::SummaryContainer> final_sketches(n_columns);
+  std::vector<typename WQSketch::SummaryContainer> final_sketches(n_columns);
   ParallelFor(n_columns, n_threads_, [&](auto fidx) {
     int32_t intermediate_num_cuts = num_cuts[fidx];
     auto nbytes =
@@ -321,8 +319,8 @@
 
     for (int32_t i = 1; i < world + 1; ++i) {
       auto size = worker_segments.at(i) - worker_segments[i - 1];
-      auto worker_sketches = Span<WQSketch::Entry>{global_sketches}.subspan(
-          worker_segments[i - 1], size);
+      auto worker_sketches =
+          Span<typename WQSketch::Entry>{global_sketches}.subspan(worker_segments[i - 1], size);
       auto worker_scan =
          Span<bst_row_t>(sketches_scan)
              .subspan((i - 1) * (n_columns + 1), (n_columns + 1));
@@ -330,8 +328,7 @@
       auto worker_feature = worker_sketches.subspan(
           worker_scan[fidx], worker_scan[fidx + 1] - worker_scan[fidx]);
       CHECK(worker_feature.data());
-      WQSummary<float, float> summary(worker_feature.data(),
-                                      worker_feature.size());
+      typename WQSketch::Summary summary(worker_feature.data(), worker_feature.size());
       auto &out = final_sketches.at(fidx);
       out.Reduce(summary, nbytes);
     }
@@ -342,10 +339,11 @@
   monitor_.Stop(__func__);
 }
 
-void AddCutPoint(WQuantileSketch<bst_float, bst_float>::SummaryContainer const &summary,
-                 int max_bin, HistogramCuts *cuts) {
+template <typename SketchType>
+void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
+                 HistogramCuts *cuts) {
   size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
-  auto& cut_values = cuts->cut_values_.HostVector();
+  auto &cut_values = cuts->cut_values_.HostVector();
   for (size_t i = 1; i < required_cuts; ++i) {
     bst_float cpt = summary.data[i].value;
     if (i == 1 || cpt > cut_values.back()) {
@@ -361,20 +359,21 @@ void AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
   }
 }
 
-void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
+template <typename WQSketch>
+void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
   monitor_.Start(__func__);
-  std::vector<WQSketch::SummaryContainer> reduced;
+  std::vector<typename WQSketch::SummaryContainer> reduced;
   std::vector<int32_t> num_cuts;
   this->AllReduce(&reduced, &num_cuts);
 
   cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
-  std::vector<WQSketch::SummaryContainer> final_summaries(reduced.size());
+  std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
 
   ParallelFor(reduced.size(), n_threads_, Sched::Guided(), [&](size_t fidx) {
     if (IsCat(feature_types_, fidx)) {
       return;
     }
-    WQSketch::SummaryContainer &a = final_summaries[fidx];
+    typename WQSketch::SummaryContainer &a = final_summaries[fidx];
     size_t max_num_bins = std::min(num_cuts[fidx], max_bins_);
     a.Reserve(max_num_bins + 1);
     CHECK(a.data);
@@ -392,11 +391,11 @@ void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
 
   for (size_t fid = 0; fid < reduced.size(); ++fid) {
     size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
-    WQSketch::SummaryContainer const& a = final_summaries[fid];
+    typename WQSketch::SummaryContainer const& a = final_summaries[fid];
     if (IsCat(feature_types_, fid)) {
       AddCategories(categories_.at(fid), cuts);
     } else {
-      AddCutPoint(a, max_num_bins, cuts);
+      AddCutPoint<WQSketch>(a, max_num_bins, cuts);
       // push a value that is greater than anything
       const bst_float cpt = (a.size > 0) ?
           a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid];
@@ -413,5 +412,64 @@ void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
   }
   monitor_.Stop(__func__);
 }
+
+template class SketchContainerImpl<WQuantileSketch<float, float>>;
+template class SketchContainerImpl<WXQuantileSketch<float, float>>;
+
+HostSketchContainer::HostSketchContainer(int32_t max_bins, MetaInfo const &info,
+                                         std::vector<size_t> columns_size, bool use_group,
+                                         Span<float const> hessian, int32_t n_threads)
+    : SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
+                          n_threads} {
+  monitor_.Init(__func__);
+  ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
+    auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
+    n_bins = std::max(n_bins, static_cast<size_t>(1));
+    auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
+    if (!IsCat(this->feature_types_, i)) {
+      sketches_[i].Init(columns_size_[i], eps);
+      sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2);
+    }
+  });
+}
+
+void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &info,
+                                        Span<float const> hessian) {
+  monitor_.Start(__func__);
+  // glue these conditions using ternary operator to avoid making data copies.
+  auto const &weights =
+      hessian.empty() ? (use_group_ind_ ? UnrollGroupWeights(info)     // use group weight
+                                        : info.weights_.HostVector())  // use sample weight
+                      : MergeWeights(info, hessian, use_group_ind_,
+                                     n_threads_);  // use hessian merged with group/sample weights
+  CHECK_EQ(weights.size(), info.num_row_);
+
+  auto view = page.GetView();
+  ParallelFor(view.Size(), n_threads_, [&](size_t fidx) {
+    auto column = view[fidx];
+    auto &sketch = sketches_[fidx];
+    sketch.Init(max_bins_);
+    // first pass
+    sketch.sum_total = 0.0;
+    for (auto c : column) {
+      sketch.sum_total += weights[c.index];
+    }
+    // second pass
+    if (IsCat(feature_types_, fidx)) {
+      for (auto c : column) {
+        categories_[fidx].emplace(AsCat(c.fvalue));
+      }
+    } else {
+      for (auto c : column) {
+        sketch.Push(c.fvalue, weights[c.index], max_bins_);
+      }
+    }
+
+    if (!IsCat(feature_types_, fidx) && !column.empty()) {
+      sketch.Finalize(max_bins_);
+    }
+  });
+  monitor_.Stop(__func__);
+}
 }  // namespace common
 }  // namespace xgboost
diff --git a/src/common/quantile.h b/src/common/quantile.h
index b93a7c8c3..37720694f 100644
--- a/src/common/quantile.h
+++ b/src/common/quantile.h
@@ -702,11 +702,9 @@ class HistogramCuts;
 /*!
  * A sketch matrix storing sketches for each feature.
  */
-class HostSketchContainer {
- public:
-  using WQSketch = WQuantileSketch<bst_float, bst_float>;
-
- private:
+template <typename WQSketch>
+class SketchContainerImpl {
+ protected:
   std::vector<WQSketch> sketches_;
   std::vector<std::set<float>> categories_;
   std::vector<FeatureType> const feature_types_;
@@ -724,7 +722,7 @@ class HostSketchContainer {
    * \param max_bins maximum number of bins for each feature.
    * \param use_group whether is assigned to group to data instance.
    */
-  HostSketchContainer(std::vector<bst_row_t> columns_size, int32_t max_bins,
+  SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
                       common::Span<FeatureType const> feature_types, bool use_group,
                       int32_t n_threads);
 
@@ -755,20 +753,139 @@ class HostSketchContainer {
     return group_ind;
   }
   // Gather sketches from all workers.
-  void GatherSketchInfo(std::vector<WQSketch::SummaryContainer> const &reduced,
+  void GatherSketchInfo(std::vector<typename WQSketch::SummaryContainer> const &reduced,
                         std::vector<size_t> *p_worker_segments,
                         std::vector<bst_row_t> *p_sketches_scan,
-                        std::vector<WQSketch::Entry> *p_global_sketches);
+                        std::vector<typename WQSketch::Entry> *p_global_sketches);
   // Merge sketches from all workers.
-  void AllReduce(std::vector<WQSketch::SummaryContainer> *p_reduced,
-                 std::vector<int32_t>* p_num_cuts);
+  void AllReduce(std::vector<typename WQSketch::SummaryContainer> *p_reduced,
+                 std::vector<int32_t> *p_num_cuts);
 
   /* \brief Push a CSR matrix. */
-  void PushRowPage(SparsePage const &page, MetaInfo const &info,
-                   Span<float const> const hessian = {});
+  void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
 
   void MakeCuts(HistogramCuts* cuts);
 };
+
+class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
+ public:
+  using WQSketch = WQuantileSketch<float, float>;
+
+ public:
+  HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
+                      bool use_group, Span<float const> hessian, int32_t n_threads);
+};
+
+/**
+ * \brief Quantile structure accepts sorted data, extracted from histmaker.
+ */
+struct SortedQuantile {
+  /*! \brief total sum of amount to be met */
+  double sum_total{0.0};
+  /*! \brief statistics used in the sketch */
+  double rmin, wmin;
+  /*! \brief last seen feature value */
+  bst_float last_fvalue;
+  /*! \brief current size of sketch */
+  double next_goal;
+  // pointer to the sketch to put things in
+  common::WXQuantileSketch<bst_float, bst_float> *sketch;
+  // initialize the space
+  inline void Init(unsigned max_size) {
+    next_goal = -1.0f;
+    rmin = wmin = 0.0f;
+    sketch->temp.Reserve(max_size + 1);
+    sketch->temp.size = 0;
+  }
+  /*!
+   * \brief push a new element to sketch
+   * \param fvalue feature value, comes in sorted ascending order
+   * \param w weight
+   * \param max_size
+   */
+  inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
+    if (next_goal == -1.0f) {
+      next_goal = 0.0f;
+      last_fvalue = fvalue;
+      wmin = w;
+      return;
+    }
+    if (last_fvalue != fvalue) {
+      double rmax = rmin + wmin;
+      if (rmax >= next_goal && sketch->temp.size != max_size) {
+        if (sketch->temp.size == 0 ||
+            last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
+          // push to sketch
+          sketch->temp.data[sketch->temp.size] =
+              common::WXQuantileSketch<bst_float, bst_float>::Entry(
+                  static_cast<bst_float>(rmin), static_cast<bst_float>(rmax),
+                  static_cast<bst_float>(wmin), last_fvalue);
+          CHECK_LT(sketch->temp.size, max_size) << "invalid maximum size max_size=" << max_size
+                                                << ", stemp.size" << sketch->temp.size;
+          ++sketch->temp.size;
+        }
+        if (sketch->temp.size == max_size) {
+          next_goal = sum_total * 2.0f + 1e-5f;
+        } else {
+          next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
+        }
+      } else {
+        if (rmax >= next_goal) {
+          LOG(DEBUG) << "INFO: rmax=" << rmax << ", sum_total=" << sum_total
+                     << ", naxt_goal=" << next_goal << ", size=" << sketch->temp.size;
+        }
+      }
+      rmin = rmax;
+      wmin = w;
+      last_fvalue = fvalue;
+    } else {
+      wmin += w;
+    }
+  }
+
+  /*! \brief push final unfinished value to the sketch */
+  inline void Finalize(unsigned max_size) {
+    double rmax = rmin + wmin;
+    if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
+      CHECK_LE(sketch->temp.size, max_size)
+          << "Finalize: invalid maximum size, max_size=" << max_size
+          << ", stemp.size=" << sketch->temp.size;
+      // push to sketch
+      sketch->temp.data[sketch->temp.size] = common::WXQuantileSketch<bst_float, bst_float>::Entry(
+          static_cast<bst_float>(rmin), static_cast<bst_float>(rmax), static_cast<bst_float>(wmin),
+          last_fvalue);
+      ++sketch->temp.size;
+    }
+    sketch->PushTemp();
+  }
+};
+
+class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float, float>> {
+  std::vector<SortedQuantile> sketches_;
+  using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
+
+ public:
+  explicit SortedSketchContainer(int32_t max_bins, MetaInfo const &info,
+                                 std::vector<size_t> columns_size, bool use_group,
+                                 Span<float const> hessian, int32_t n_threads)
+      : SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
+                            n_threads} {
+    monitor_.Init(__func__);
+    sketches_.resize(info.num_col_);
+    size_t i = 0;
+    for (auto &sketch : sketches_) {
+      sketch.sketch = &Super::sketches_[i];
+      sketch.Init(max_bins_);
+      auto eps = 2.0 / max_bins;
+      sketch.sketch->Init(columns_size_[i], eps);
+      ++i;
+    }
+  }
+  /**
+   * \brief Push a sorted CSC page.
+   */
+  void PushColPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian);
+};
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_QUANTILE_H_
diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc
index b0eea203e..a004f5231 100644
--- a/src/data/gradient_index.cc
+++ b/src/data/gradient_index.cc
@@ -118,7 +118,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
                      [](auto idx, auto) { return idx; });
   }
 
-  common::ParallelFor(bst_omp_uint(nbins), n_threads, [&](bst_omp_uint idx) {
+  common::ParallelFor(nbins, n_threads, [&](bst_omp_uint idx) {
     for (int32_t tid = 0; tid < n_threads; ++tid) {
       hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
       hit_count_tloc_[tid * nbins + idx] = 0;  // reset for next batch
@@ -126,8 +126,11 @@
   });
 }
 
-void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins, common::Span<float> hess) {
-  cut = common::SketchOnDMatrix(p_fmat, max_bins, hess);
+void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch,
+                            common::Span<float> hess) {
+  // We use sorted sketching for the approx tree method since it's more efficient in
+  // computation time (at the cost of higher memory usage).
+  cut = common::SketchOnDMatrix(p_fmat, max_bins, sorted_sketch, hess);
   max_num_bins = max_bins;
   const int32_t nthread = omp_get_max_threads();
diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h
index a12ebfad6..f30e2267e 100644
--- a/src/data/gradient_index.h
+++ b/src/data/gradient_index.h
@@ -37,14 +37,14 @@ class GHistIndexMatrix {
   size_t base_rowid{0};
 
   GHistIndexMatrix() = default;
-  GHistIndexMatrix(DMatrix* x, int32_t max_bin, common::Span<float> hess = {}) {
-    this->Init(x, max_bin, hess);
+  GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, common::Span<float> hess = {}) {
+    this->Init(x, max_bin, sorted_sketch, hess);
   }
   // Create a global histogram matrix, given cut
-  void Init(DMatrix* p_fmat, int max_num_bins, common::Span<float> hess);
-  void Init(SparsePage const &page, common::Span<FeatureType const> ft,
-            common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
-            bool is_dense, int32_t n_threads);
+  void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, common::Span<float> hess);
+  void Init(SparsePage const& page, common::Span<FeatureType const> ft,
+            common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
+            int32_t n_threads);
 
   // specific method for sparse data as no possibility to reduce allocated memory
   template
@@ -57,7 +57,9 @@ class GHistIndexMatrix {
     const size_t batch_size = batch.Size();
     CHECK_LT(batch_size, offset_vec.size());
     BinIdxType* index_data = index_data_span.data();
-    common::ParallelFor(omp_ulong(batch_size), batch_threads, [&](omp_ulong i) {
+    auto const& ptrs = cut.Ptrs();
+    auto const& values = cut.Values();
+    common::ParallelFor(batch_size, batch_threads, [&](omp_ulong i) {
      const int tid = omp_get_thread_num();
      size_t ibegin = row_ptr[rbegin + i];
      size_t iend = row_ptr[rbegin + i + 1];
@@ -71,7 +73,7 @@ class GHistIndexMatrix {
           index_data[ibegin + j] = get_offset(bin_idx, j);
           ++hit_count_tloc_[tid * nbins + bin_idx];
         } else {
-          uint32_t idx = cut.SearchBin(inst[j]);
+          uint32_t idx = cut.SearchBin(inst[j].fvalue, inst[j].index, ptrs, values);
           index_data[ibegin + j] = get_offset(idx, j);
           ++hit_count_tloc_[tid * nbins + idx];
         }
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index 66a5c0d3e..9ec343d91 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -94,10 +94,12 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
   if (!(batch_param_ != BatchParam{})) {
     CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
   }
-  if (!gradient_index_ || (batch_param_ != param && param != BatchParam{}) || param.regen) {
+  if (!gradient_index_ || (batch_param_ != param && param != BatchParam{}) || param.regen) {
     CHECK_GE(param.max_bin, 2);
     CHECK_EQ(param.gpu_id, -1);
-    gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, param.hess));
+    // Used only by approx.
+    auto sorted_sketch = param.regen;
+    gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, sorted_sketch, param.hess));
     batch_param_ = param;
     CHECK_EQ(batch_param_.hess.data(), param.hess.data());
   }
diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc
index db2e298df..bb95e0d99 100644
--- a/src/data/sparse_page_dmatrix.cc
+++ b/src/data/sparse_page_dmatrix.cc
@@ -159,12 +159,12 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
 
 BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam& param) {
   CHECK_GE(param.max_bin, 2);
-  if (param.hess.empty()) {
+  if (param.hess.empty() && !param.regen) {
     // hist method doesn't support full external memory implementation, so we concatenate
     // all index here.
     if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
       this->InitializeSparsePage();
-      ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin});
+      ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.regen});
       this->InitializeSparsePage();
       batch_param_ = param;
     }
@@ -175,20 +175,23 @@
 
   auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
   this->InitializeSparsePage();
-  if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{})) {
+  if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{}) ||
+      param.regen) {
     cache_info_.erase(id);
     MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
-    auto cuts = common::SketchOnDMatrix(this, param.max_bin, param.hess);
+    // Use sorted sketch for approx.
+    auto sorted_sketch = param.regen;
+    auto cuts = common::SketchOnDMatrix(this, param.max_bin, sorted_sketch, param.hess);
     this->InitializeSparsePage();  // reset after use.
 
     batch_param_ = param;
     ghist_index_source_.reset();
     CHECK_NE(cuts.Values().size(), 0);
     auto ft = this->info_.feature_types.ConstHostSpan();
-    ghist_index_source_.reset(new GradientIndexPageSource(
-        this->missing_, this->ctx_.Threads(), this->Info().num_col_,
-        this->n_batches_, cache_info_.at(id), param, std::move(cuts),
-        this->IsDense(), param.max_bin, ft, sparse_page_source_));
+    ghist_index_source_.reset(
+        new GradientIndexPageSource(this->missing_, this->ctx_.Threads(), this->Info().num_col_,
+                                    this->n_batches_, cache_info_.at(id), param, std::move(cuts),
+                                    this->IsDense(), param.max_bin, ft, sparse_page_source_));
   } else {
     CHECK(ghist_index_source_);
     ghist_index_source_->Reset();
diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h
index 6a605ef04..c7c60b750 100644
--- a/src/tree/updater_basemaker-inl.h
+++ b/src/tree/updater_basemaker-inl.h
@@ -369,92 +369,7 @@ class BaseMaker: public TreeUpdater {
       }
     }
   }
-  /*! \brief common helper data structure to build sketch */
-  struct SketchEntry {
-    /*! \brief total sum of amount to be met */
-    double sum_total;
-    /*! \brief statistics used in the sketch */
-    double rmin, wmin;
-    /*! \brief last seen feature value */
-    bst_float last_fvalue;
-    /*! \brief current size of sketch */
-    double next_goal;
-    // pointer to the sketch to put things in
-    common::WXQuantileSketch<bst_float, bst_float> *sketch;
-    // initialize the space
-    inline void Init(unsigned max_size) {
-      next_goal = -1.0f;
-      rmin = wmin = 0.0f;
-      sketch->temp.Reserve(max_size + 1);
-      sketch->temp.size = 0;
-    }
-    /*!
-     * \brief push a new element to sketch
-     * \param fvalue feature value, comes in sorted ascending order
-     * \param w weight
-     * \param max_size
-     */
-    inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
-      if (next_goal == -1.0f) {
-        next_goal = 0.0f;
-        last_fvalue = fvalue;
-        wmin = w;
-        return;
-      }
-      if (last_fvalue != fvalue) {
-        double rmax = rmin + wmin;
-        if (rmax >= next_goal && sketch->temp.size != max_size) {
-          if (sketch->temp.size == 0 ||
-              last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
-            // push to sketch
-            sketch->temp.data[sketch->temp.size] =
-                common::WXQuantileSketch<bst_float, bst_float>::
-                Entry(static_cast<bst_float>(rmin),
-                      static_cast<bst_float>(rmax),
-                      static_cast<bst_float>(wmin), last_fvalue);
-            CHECK_LT(sketch->temp.size, max_size)
-                << "invalid maximum size max_size=" << max_size
-                << ", stemp.size" << sketch->temp.size;
-            ++sketch->temp.size;
-          }
-          if (sketch->temp.size == max_size) {
-            next_goal = sum_total * 2.0f + 1e-5f;
-          } else {
-            next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
-          }
-        } else {
-          if (rmax >= next_goal) {
-            LOG(TRACKER) << "INFO: rmax=" << rmax
-                         << ", sum_total=" << sum_total
-                         << ", naxt_goal=" << next_goal
-                         << ", size=" << sketch->temp.size;
-          }
-        }
-        rmin = rmax;
-        wmin = w;
-        last_fvalue = fvalue;
-      } else {
-        wmin += w;
-      }
-    }
-    /*! \brief push final unfinished value to the sketch */
-    inline void Finalize(unsigned max_size) {
-      double rmax = rmin + wmin;
-      if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
-        CHECK_LE(sketch->temp.size, max_size)
-            << "Finalize: invalid maximum size, max_size=" << max_size
-            << ", stemp.size=" << sketch->temp.size;
-        // push to sketch
-        sketch->temp.data[sketch->temp.size] =
-            common::WXQuantileSketch<bst_float, bst_float>::
-            Entry(static_cast<bst_float>(rmin),
-                  static_cast<bst_float>(rmax),
-                  static_cast<bst_float>(wmin), last_fvalue);
-        ++sketch->temp.size;
-      }
-      sketch->PushTemp();
-    }
-  };
+  using SketchEntry = common::SortedQuantile;
   /*! \brief training parameter of tree grower */
   TrainParam param_;
   /*! \brief queue of nodes to be expanded */
diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc
index 5d5f03afd..379d364e1 100644
--- a/tests/cpp/common/test_column_matrix.cc
+++ b/tests/cpp/common/test_column_matrix.cc
@@ -14,7 +14,7 @@ TEST(DenseColumn, Test) {
                                 static_cast(std::numeric_limits::max()) + 2};
   for (size_t max_num_bin : max_num_bins) {
     auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
-    GHistIndexMatrix gmat(dmat.get(), max_num_bin);
+    GHistIndexMatrix gmat(dmat.get(), max_num_bin, false);
     ColumnMatrix column_matrix;
     column_matrix.Init(gmat, 0.2);
@@ -61,7 +61,7 @@ TEST(SparseColumn, Test) {
                                 static_cast(std::numeric_limits::max()) + 2};
   for (size_t max_num_bin : max_num_bins) {
     auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
-    GHistIndexMatrix gmat(dmat.get(), max_num_bin);
+    GHistIndexMatrix gmat(dmat.get(), max_num_bin, false);
     ColumnMatrix column_matrix;
     column_matrix.Init(gmat, 0.5);
     switch (column_matrix.GetTypeSize()) {
@@ -101,7 +101,7 @@ TEST(DenseColumnWithMissing, Test) {
                                 static_cast(std::numeric_limits::max()) + 2 };
   for (size_t max_num_bin : max_num_bins) {
     auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
-    GHistIndexMatrix gmat(dmat.get(), max_num_bin);
+    GHistIndexMatrix gmat(dmat.get(), max_num_bin, false);
     ColumnMatrix column_matrix;
     column_matrix.Init(gmat, 0.2);
     switch (column_matrix.GetTypeSize()) {
@@ -130,7 +130,7 @@ void TestGHistIndexMatrixCreation(size_t nthreads) {
   /* This should create multiple sparse pages */
   std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries) };
   omp_set_num_threads(nthreads);
-  GHistIndexMatrix gmat(dmat.get(), 256);
+  GHistIndexMatrix gmat(dmat.get(), 256, false);
 }
 
 TEST(HistIndexCreationWithExternalMemory, Test) {
diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc
index b59e994a2..350e544cf 100644
--- a/tests/cpp/common/test_hist_util.cc
+++ b/tests/cpp/common/test_hist_util.cc
@@ -223,13 +223,19 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
     auto w = GenerateRandomWeights(num_rows);
     dmat->Info().weights_.HostVector() = w;
     for (auto num_bins : bin_sizes) {
-      HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins);
-      ValidateCuts(cuts, dmat.get(), num_bins);
+      {
+        HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, true);
+        ValidateCuts(cuts, dmat.get(), num_bins);
+      }
+      {
+        HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, false);
+        ValidateCuts(cuts, dmat.get(), num_bins);
+      }
     }
   }
 }
 
-TEST(HistUtil, QuantileWithHessian) {
+void TestQuantileWithHessian(bool use_sorted) {
   int bin_sizes[] = {2, 16, 256, 512};
   int sizes[] = {1000, 1500};
   int num_columns = 5;
@@ -243,13 +249,13 @@ TEST(HistUtil, QuantileWithHessian) {
     dmat->Info().weights_.HostVector() = w;
 
     for (auto num_bins : bin_sizes) {
-      HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, hessian);
+      HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, use_sorted, hessian);
      for (size_t i = 0; i < w.size(); ++i) {
        dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i];
      }
      ValidateCuts(cuts_hess, dmat.get(), num_bins);
 
-      HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins);
+      HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins, use_sorted);
      ValidateCuts(cuts_wh, dmat.get(), num_bins);
 
      ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
@@ -262,6 +268,11 @@ TEST(HistUtil, QuantileWithHessian) {
   }
 }
 
+TEST(HistUtil, QuantileWithHessian) {
+  TestQuantileWithHessian(true);
+  TestQuantileWithHessian(false);
+}
+
 TEST(HistUtil, DenseCutsExternalMemory) {
   int bin_sizes[] = {2, 16, 256, 512};
   int sizes[] = {100, 1000, 1500};
@@ -292,7 +303,7 @@ TEST(HistUtil, IndexBinBound) {
 
   for (auto max_bin : bin_sizes) {
     auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
-    GHistIndexMatrix hmat(p_fmat.get(), max_bin);
+    GHistIndexMatrix hmat(p_fmat.get(), max_bin, false);
     EXPECT_EQ(hmat.index.Size(), kRows*kCols);
     EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
   }
@@ -315,7 +326,7 @@ TEST(HistUtil, IndexBinData) {
 
   for (auto max_bin : kBinSizes) {
     auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
-    GHistIndexMatrix hmat(p_fmat.get(), max_bin);
+    GHistIndexMatrix hmat(p_fmat.get(), max_bin, false);
     uint32_t* offsets = hmat.index.Offset();
     EXPECT_EQ(hmat.index.Size(), kRows*kCols);
     switch (max_bin) {
diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc
index eb4372500..5fd67f46d 100644
--- a/tests/cpp/common/test_quantile.cc
+++ b/tests/cpp/common/test_quantile.cc
@@ -19,7 +19,22 @@ TEST(Quantile, LoadBalance) {
   }
   CHECK_EQ(n_cols, kCols);
 }
 
+namespace {
+template <bool use_column>
+using ContainerType = std::conditional_t<use_column, SortedSketchContainer, HostSketchContainer>;
+// Dispatch for push page.
+void PushPage(SortedSketchContainer* container, SparsePage const& page, MetaInfo const& info,
+              Span<float const> hessian) {
+  container->PushColPage(page, info, hessian);
+}
+void PushPage(HostSketchContainer* container, SparsePage const& page, MetaInfo const& info,
+              Span<float const> hessian) {
+  container->PushRowPage(page, info, hessian);
+}
+}  // anonymous namespace
+
+template <bool use_column>
 void TestDistributedQuantile(size_t rows, size_t cols) {
   std::string msg {"Skipping AllReduce test"};
   int32_t constexpr kWorkers = 4;
@@ -48,12 +63,23 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
                .Lower(.0f)
                .Upper(1.0f)
                .GenerateDMatrix();
-  HostSketchContainer sketch_distributed(
-      column_size, n_bins, m->Info().feature_types.ConstHostSpan(), false,
-      OmpGetNumThreads(0));
-  for (auto const &page : m->GetBatches<SparsePage>()) {
-    sketch_distributed.PushRowPage(page, m->Info());
+
+  std::vector<float> hessian(rows, 1.0);
+  auto hess = Span<float const>{hessian};
+
+  ContainerType<use_column> sketch_distributed(n_bins, m->Info(), column_size, false, hess,
+                                               OmpGetNumThreads(0));
+
+  if (use_column) {
+    for (auto const& page : m->GetBatches<SortedCSCPage>()) {
+      PushPage(&sketch_distributed, page, m->Info(), hess);
+    }
+  } else {
+    for (auto const& page : m->GetBatches<SparsePage>()) {
+      PushPage(&sketch_distributed, page, m->Info(), hess);
+    }
   }
+
   HistogramCuts distributed_cuts;
   sketch_distributed.MakeCuts(&distributed_cuts);
@@ -61,17 +87,25 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
   rabit::Finalize();
   CHECK_EQ(rabit::GetWorldSize(), 1);
   std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
-  HostSketchContainer sketch_on_single_node(
-      column_size, n_bins, m->Info().feature_types.ConstHostSpan(), false,
-      OmpGetNumThreads(0));
+  m->Info().num_row_ = world * rows;
+  ContainerType<use_column> sketch_on_single_node(n_bins, m->Info(), column_size, false, hess,
+                                                  OmpGetNumThreads(0));
+  m->Info().num_row_ = rows;
+
   for (auto rank = 0; rank < world; ++rank) {
     auto m = RandomDataGenerator{rows, cols, sparsity}
                  .Seed(rank)
                  .Lower(.0f)
                  .Upper(1.0f)
                  .GenerateDMatrix();
-    for (auto const &page : m->GetBatches<SparsePage>()) {
-      sketch_on_single_node.PushRowPage(page, m->Info());
+    if (use_column) {
+      for (auto const& page : m->GetBatches<SortedCSCPage>()) {
+        PushPage(&sketch_on_single_node, page, m->Info(), hess);
+      }
+    } else {
+      for (auto const& page : m->GetBatches<SparsePage>()) {
+        PushPage(&sketch_on_single_node, page, m->Info(), hess);
+      }
     }
   }
@@ -87,7 +121,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
 
   ASSERT_EQ(sptrs.size(), dptrs.size());
   for (size_t i = 0; i < sptrs.size(); ++i) {
-    ASSERT_EQ(sptrs[i], dptrs[i]);
+    ASSERT_EQ(sptrs[i], dptrs[i]) << i;
   }
 
   ASSERT_EQ(svals.size(), dvals.size());
@@ -104,14 +138,28 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
 TEST(Quantile, DistributedBasic) {
 #if defined(__unix__)
   constexpr size_t kRows = 10, kCols = 10;
-  TestDistributedQuantile(kRows, kCols);
+  TestDistributedQuantile<false>(kRows, kCols);
 #endif
 }
 
 TEST(Quantile, Distributed) {
 #if defined(__unix__)
-  constexpr size_t kRows = 1000, kCols = 200;
-  TestDistributedQuantile(kRows, kCols);
+  constexpr size_t kRows = 4000, kCols = 200;
+  TestDistributedQuantile<false>(kRows, kCols);
+#endif
+}
+
+TEST(Quantile, SortedDistributedBasic) {
+#if defined(__unix__)
+  constexpr size_t kRows = 10, kCols = 10;
+  TestDistributedQuantile<true>(kRows, kCols);
+#endif
+}
+
+TEST(Quantile, SortedDistributed) {
+#if defined(__unix__)
+  constexpr size_t kRows = 4000, kCols = 200;
+  TestDistributedQuantile<true>(kRows, kCols);
 #endif
 }
diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc
index 2dcb5ed1a..bbb0f6de4 100644
--- a/tests/cpp/data/test_gradient_index.cc
+++ b/tests/cpp/data/test_gradient_index.cc
@@ -36,7 +36,7 @@ TEST(GradientIndex, FromCategoricalBasic) {
   BatchParam p(0, max_bins);
   GHistIndexMatrix gidx;
 
-  gidx.Init(m.get(), max_bins, {});
+  gidx.Init(m.get(), max_bins, false, {});
 
   auto x_copy = x;
   std::sort(x_copy.begin(), x_copy.end());
diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc
index 59d1ef232..d7adde257 100644
--- a/tests/cpp/tree/hist/test_evaluate_splits.cc
+++ b/tests/cpp/tree/hist/test_evaluate_splits.cc
@@ -29,7 +29,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
   size_t constexpr kMaxBins = 4;
 
   // dense, no missing values
-  GHistIndexMatrix gmat(dmat.get(), kMaxBins);
+  GHistIndexMatrix gmat(dmat.get(), kMaxBins, false);
   common::RowSetCollection row_set_collection;
   std::vector<size_t> &row_indices = *row_set_collection.Data();
   row_indices.resize(kRows);
diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index 534dd2a9e..006cbf30d 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -162,7 +162,7 @@ class QuantileHistMock : public QuantileHistMaker {
     // kNRows samples with kNCols features
     auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
 
-    GHistIndexMatrix gmat(dmat.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat.get(), kMaxBins, false);
     ColumnMatrix cm;
 
     // treat everything as dense, as this is what we intend to test here
@@ -253,7 +253,7 @@ class QuantileHistMock : public QuantileHistMaker {
   void TestInitData() {
     size_t constexpr kMaxBins = 4;
-    GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false);
     RegTree tree = RegTree();
     tree.param.UpdateAllowUnknown(cfg_);
@@ -270,7 +270,7 @@ class QuantileHistMock : public QuantileHistMaker {
   void TestInitDataSampling() {
     size_t constexpr kMaxBins = 4;
-    GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false);
     RegTree tree = RegTree();
     tree.param.UpdateAllowUnknown(cfg_);
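For orientation, the snippet below sketches how the two sketching paths added by this patch are selected through the new SketchOnDMatrix signature. It is an illustrative sketch only, based on the signatures shown in this diff: the helper name SketchBothWays, the 64-bin value, and the include path are assumptions and are not part of the patch.

// Illustrative sketch only (not from this patch); assumes it is compiled inside the
// xgboost source tree so that the internal header below is reachable.
#include <xgboost/data.h>          // xgboost::DMatrix
#include "../src/common/hist_util.h"  // SketchOnDMatrix, HistogramCuts (path is an assumption)

namespace {
void SketchBothWays(xgboost::DMatrix* p_fmat) {  // hypothetical helper
  // Default path: row-wise weighted-quantile sketching (HostSketchContainer), used by hist.
  auto hist_cuts = xgboost::common::SketchOnDMatrix(p_fmat, /*max_bins=*/64);
  // New path in this patch: sorted-CSC sketching (SortedSketchContainer), used by approx;
  // faster, at the cost of materializing SortedCSCPage batches.
  auto approx_cuts =
      xgboost::common::SketchOnDMatrix(p_fmat, /*max_bins=*/64, /*use_sorted=*/true);
  // Both paths return a HistogramCuts object with the same column layout.
  CHECK_EQ(hist_cuts.Ptrs().size(), approx_cuts.Ptrs().size());
}
}  // anonymous namespace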