diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index f78a7ed09..24b99ed4a 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -56,15 +56,15 @@ template class HistEvaluator { // a non-missing value for the particular feature fid. template GradStats EnumerateSplit( - const GHistIndexMatrix &gmat, const common::GHistRow &hist, + common::HistogramCuts const &cut, const common::GHistRow &hist, const NodeEntry &snode, SplitEntry *p_best, bst_feature_t fidx, bst_node_t nidx, TreeEvaluator::SplitEvaluator const &evaluator) const { static_assert(d_step == +1 || d_step == -1, "Invalid step."); // aliases - const std::vector &cut_ptr = gmat.cut.Ptrs(); - const std::vector &cut_val = gmat.cut.Values(); + const std::vector &cut_ptr = cut.Ptrs(); + const std::vector &cut_val = cut.Values(); // statistics on both sides of split GradStats c; @@ -116,7 +116,7 @@ template class HistEvaluator { snode.root_gain); if (i == imin) { // for leftmost bin, left bound is the smallest feature value - split_pt = gmat.cut.MinValues()[fidx]; + split_pt = cut.MinValues()[fidx]; } else { split_pt = cut_val[i - 1]; } @@ -132,7 +132,7 @@ template class HistEvaluator { public: void EvaluateSplits(const common::HistCollection &hist, - GHistIndexMatrix const &gidx, const RegTree &tree, + common::HistogramCuts const &cut, const RegTree &tree, std::vector* p_entries) { auto& entries = *p_entries; // All nodes are on the same level, so we can store the shared ptr. @@ -168,10 +168,10 @@ template class HistEvaluator { for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) { auto fidx = features_set[fidx_in_set]; if (interaction_constraints_.Query(nidx, fidx)) { - auto grad_stats = EnumerateSplit<+1>(gidx, histogram, snode_[nidx], + auto grad_stats = EnumerateSplit<+1>(cut, histogram, snode_[nidx], best, fidx, nidx, evaluator); if (SplitContainsMissingValues(grad_stats, snode_[nidx])) { - EnumerateSplit<-1>(gidx, histogram, snode_[nidx], best, fidx, nidx, + EnumerateSplit<-1>(cut, histogram, snode_[nidx], best, fidx, nidx, evaluator); } } diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 6d426b2f7..19c300b30 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -170,7 +170,8 @@ void QuantileHistMaker::Builder::InitRoot( builder_monitor_.Start("EvaluateSplits"); for (auto const &gmat : p_fmat->GetBatches( BatchParam{GenericParameter::kCpuId, param_.max_bin})) { - evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat, *p_tree, &entries); + evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, *p_tree, &entries); + break; } builder_monitor_.Stop("EvaluateSplits"); node = entries.front(); @@ -271,7 +272,7 @@ void QuantileHistMaker::Builder::ExpandTree( } builder_monitor_.Start("EvaluateSplits"); - evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), gmat, + evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), gmat.cut, *p_tree, &nodes_to_evaluate); builder_monitor_.Stop("EvaluateSplits"); diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc index 4bdf34ab2..2c19b9e58 100644 --- a/tests/cpp/data/test_gradient_index.cc +++ b/tests/cpp/data/test_gradient_index.cc @@ -13,7 +13,8 @@ TEST(GradientIndex, ExternalMemory) { std::unique_ptr dmat = CreateSparsePageDMatrix(10000); std::vector base_rowids; std::vector hessian(dmat->Info().num_row_, 1); - for (auto const& page : dmat->GetBatches({0, 64, hessian})) { + for (auto const &page : dmat->GetBatches( + {GenericParameter::kCpuId, 64, hessian})) { base_rowids.push_back(page.base_rowid); } size_t i = 0; diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index c9228edf9..cb0171269 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -58,7 +58,7 @@ template void TestEvaluateSplits() { entries.front().depth = 0; evaluator.InitRoot(GradStats{total_gpair}); - evaluator.EvaluateSplits(hist, gmat, tree, &entries); + evaluator.EvaluateSplits(hist, gmat.cut, tree, &entries); auto best_loss_chg = evaluator.Evaluator().CalcSplitGain(