From e060519d4f0f43809be34ec72c4041d6e89995e7 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Wed, 26 Jan 2022 21:41:30 +0800
Subject: [PATCH] Avoid regenerating the gradient index for approx. (#7591)

---
 src/data/gradient_index.cc      |  3 +--
 src/data/gradient_index.h       | 11 +++++++++++
 src/data/simple_dmatrix.cc      |  3 ++-
 src/data/sparse_page_dmatrix.cc |  6 +++---
 src/data/sparse_page_dmatrix.cu |  2 +-
 src/tree/hist/evaluate_splits.h |  1 +
 src/tree/updater_approx.cc      | 32 +++++++++++++++++++++-----------
 7 files changed, 40 insertions(+), 18 deletions(-)

diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc
index c68276e9a..4f815fd04 100644
--- a/src/data/gradient_index.cc
+++ b/src/data/gradient_index.cc
@@ -17,8 +17,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
   // block is parallelized on anything other than the batch/block size,
   // it should be reassigned
   const size_t batch_threads =
-      std::max(size_t(1), std::min(batch.Size(),
-                                   static_cast<size_t>(n_threads)));
+      std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
   auto page = batch.GetView();
   common::MemStackAllocator<size_t, 128> partial_sums(batch_threads);
   size_t *p_part = partial_sums.Get();
diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h
index 58f3a0753..76062b57c 100644
--- a/src/data/gradient_index.h
+++ b/src/data/gradient_index.h
@@ -108,5 +108,16 @@ class GHistIndexMatrix {
   std::vector<size_t> hit_count_tloc_;
   bool isDense_;
 };
+
+/**
+ * \brief Should we regenerate the gradient index?
+ *
+ * \param old Parameter stored in DMatrix.
+ * \param p   New parameter passed in by caller.
+ */
+inline bool RegenGHist(BatchParam old, BatchParam p) {
+  // parameter is renewed or caller requests a regen
+  return p.regen || (old.gpu_id != p.gpu_id || old.max_bin != p.max_bin);
+}
 }  // namespace xgboost
 #endif  // XGBOOST_DATA_GRADIENT_INDEX_H_
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index d447f14ce..3e1e1de79 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -94,7 +94,8 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
   if (!(batch_param_ != BatchParam{})) {
     CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
   }
-  if (!gradient_index_ || (batch_param_ != param && param != BatchParam{}) || param.regen) {
+  if (!gradient_index_ || RegenGHist(batch_param_, param)) {
+    LOG(INFO) << "Generating new Gradient Index.";
     CHECK_GE(param.max_bin, 2);
     CHECK_EQ(param.gpu_id, -1);
     // Used only by approx.
diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc
index 0ce3b8c38..22ad0f85d 100644
--- a/src/data/sparse_page_dmatrix.cc
+++ b/src/data/sparse_page_dmatrix.cc
@@ -157,7 +157,7 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
   return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(begin_iter));
 }
 
-BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam& param) {
+BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam &param) {
   CHECK_GE(param.max_bin, 2);
   if (param.hess.empty() && !param.regen) {
     // hist method doesn't support full external memory implementation, so we concatenate
@@ -176,10 +176,10 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam&
 
   auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
   this->InitializeSparsePage();
-  if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{}) ||
-      param.regen) {
+  if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) {
     cache_info_.erase(id);
     MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
+    LOG(INFO) << "Generating new Gradient Index.";
     // Use sorted sketch for approx.
     auto sorted_sketch = param.regen;
     auto cuts =
diff --git a/src/data/sparse_page_dmatrix.cu b/src/data/sparse_page_dmatrix.cu
index 0ffc4c45a..82e1f3ce0 100644
--- a/src/data/sparse_page_dmatrix.cu
+++ b/src/data/sparse_page_dmatrix.cu
@@ -14,7 +14,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
   auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
   size_t row_stride = 0;
   this->InitializeSparsePage();
-  if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{})) {
+  if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) {
     // reinitialize the cache
     cache_info_.erase(id);
     MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h
index 2d3a44226..9fde7ee38 100644
--- a/src/tree/hist/evaluate_splits.h
+++ b/src/tree/hist/evaluate_splits.h
@@ -351,6 +351,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
 
   auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
   auto const& Stats() const { return snode_; }
+  auto Task() const { return task_; }
 
   float InitRoot(GradStats const& root_sum) {
     snode_.resize(1);
diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc
index 9ae8c12ae..6acc096e0 100644
--- a/src/tree/updater_approx.cc
+++ b/src/tree/updater_approx.cc
@@ -26,6 +26,19 @@ namespace tree {
 
 DMLC_REGISTRY_FILE_TAG(updater_approx);
 
+namespace {
+// Return the BatchParam used by DMatrix.
+template <typename GradientSumT>
+auto BatchSpec(TrainParam const &p, common::Span<float> hess,
+               HistEvaluator<GradientSumT, CPUExpandEntry> const &evaluator) {
+  return BatchParam{GenericParameter::kCpuId, p.max_bin, hess, !evaluator.Task().const_hess};
+}
+
+auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
+  return BatchParam{GenericParameter::kCpuId, p.max_bin, hess, false};
+}
+}  // anonymous namespace
+
 template <typename GradientSumT>
 class GloablApproxBuilder {
  protected:
@@ -46,12 +59,13 @@ class GloablApproxBuilder {
  public:
   void InitData(DMatrix *p_fmat, common::Span<float> hess) {
     monitor_->Start(__func__);
+    n_batches_ = 0;
 
     int32_t n_total_bins = 0;
     partitioner_.clear();
     // Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
-    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(
-             {GenericParameter::kCpuId, param_.max_bin, hess, true})) {
+    for (auto const &page :
+         p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess, evaluator_))) {
       if (n_total_bins == 0) {
         n_total_bins = page.cut.TotalBins();
         feature_values_ = page.cut;
@@ -62,9 +76,8 @@ class GloablApproxBuilder {
       n_batches_++;
     }
 
-    histogram_builder_.Reset(n_total_bins,
-                             BatchParam{GenericParameter::kCpuId, param_.max_bin, hess},
-                             ctx_->Threads(), n_batches_, rabit::IsDistributed());
+    histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
+                             rabit::IsDistributed());
     monitor_->Stop(__func__);
   }
 
@@ -82,8 +95,7 @@ class GloablApproxBuilder {
     std::vector<CPUExpandEntry> nodes{best};
     size_t i = 0;
     auto space = this->ConstructHistSpace(nodes);
-    for (auto const &page :
-         p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
+    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
       histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
                                    {}, gpair);
       i++;
@@ -175,8 +187,7 @@ class GloablApproxBuilder {
 
     size_t i = 0;
     auto space = this->ConstructHistSpace(nodes_to_build);
-    for (auto const &page :
-         p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
+    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
       histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
                                    nodes_to_build, nodes_to_sub, gpair);
       i++;
@@ -225,8 +236,7 @@ class GloablApproxBuilder {
 
       monitor_->Start("UpdatePosition");
       size_t i = 0;
-      for (auto const &page :
-           p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
+      for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
        partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree);
        i++;
      }
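
Note (not part of the patch): a minimal, self-contained sketch of how the RegenGHist
predicate introduced above behaves. The BatchParam struct below is a simplified
stand-in carrying only the three fields the predicate reads (gpu_id, max_bin, regen);
the real xgboost::BatchParam has additional members such as the hessian span.

    #include <cassert>

    // Simplified stand-in for xgboost::BatchParam, for illustration only.
    struct BatchParam {
      int gpu_id{-1};
      int max_bin{0};
      bool regen{false};
    };

    // Mirrors the helper added to src/data/gradient_index.h in this patch.
    inline bool RegenGHist(BatchParam old, BatchParam p) {
      // parameter is renewed or caller requests a regen
      return p.regen || (old.gpu_id != p.gpu_id || old.max_bin != p.max_bin);
    }

    int main() {
      BatchParam cached{-1, 256, false};
      assert(!RegenGHist(cached, {-1, 256, false}));  // same parameters: reuse the cached index
      assert(RegenGHist(cached, {-1, 128, false}));   // max_bin changed: rebuild
      assert(RegenGHist(cached, {-1, 256, true}));    // caller explicitly requests a regen
      return 0;
    }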