/*!
 * Copyright 2017-2021 by Contributors
 * \file updater_quantile_hist.cc
 * \brief use quantized feature values to construct a tree
 * \author Philip Cho, Tianqi Chen, Egor Smirnov
 */
#include <dmlc/timer.h>
#include <rabit/rabit.h>

#include <cmath>
#include <memory>
#include <vector>
#include <algorithm>
#include <queue>
#include <iomanip>
#include <numeric>
#include <string>
#include <utility>

#include "xgboost/logging.h"
#include "xgboost/tree_updater.h"

#include "constraints.h"
#include "param.h"
#include "./updater_quantile_hist.h"
#include "./split_evaluator.h"
#include "../common/random.h"
#include "../common/hist_util.h"
#include "../common/row_set.h"
#include "../common/column_matrix.h"
#include "../common/threading_utils.h"

namespace xgboost {
namespace tree {

DMLC_REGISTRY_FILE_TAG(updater_quantile_hist);

DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam);

void QuantileHistMaker::Configure(const Args& args) {
  // initialize pruner
  if (!pruner_) {
    pruner_.reset(TreeUpdater::Create("prune", tparam_));
  }
  pruner_->Configure(args);
  param_.UpdateAllowUnknown(args);
  hist_maker_param_.UpdateAllowUnknown(args);
}

template <typename GradientSumT>
void QuantileHistMaker::SetBuilder(const size_t n_trees,
                                   std::unique_ptr<Builder<GradientSumT>>* builder,
                                   DMatrix* dmat) {
  builder->reset(new Builder<GradientSumT>(n_trees, param_, std::move(pruner_),
                                           int_constraint_, dmat));
  if (rabit::IsDistributed()) {
    (*builder)->SetHistSynchronizer(new DistributedHistSynchronizer<GradientSumT>());
    (*builder)->SetHistRowsAdder(new DistributedHistRowsAdder<GradientSumT>());
  } else {
    (*builder)->SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
    (*builder)->SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
  }
}

template <typename GradientSumT>
void QuantileHistMaker::CallBuilderUpdate(
    const std::unique_ptr<Builder<GradientSumT>>& builder,
    HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
    const std::vector<RegTree*>& trees) {
  for (auto tree : trees) {
    builder->Update(gmat_, gmatb_, column_matrix_, gpair, dmat, tree);
  }
}

void QuantileHistMaker::Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
                               const std::vector<RegTree*>& trees) {
  if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
    updater_monitor_.Start("GmatInitialization");
    gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
    column_matrix_.Init(gmat_, param_.sparse_threshold);
    if (param_.enable_feature_grouping > 0) {
      gmatb_.Init(gmat_, column_matrix_, param_);
    }
    updater_monitor_.Stop("GmatInitialization");
    // A proper solution is putting the cut matrix in DMatrix, see:
    // https://github.com/dmlc/xgboost/issues/5143
    is_gmat_initialized_ = true;
  }
  // rescale learning rate according to size of trees
  float lr = param_.learning_rate;
  param_.learning_rate = lr / trees.size();
  int_constraint_.Configure(param_, dmat->Info().num_col_);
  // build tree
  const size_t n_trees = trees.size();
  if (hist_maker_param_.single_precision_histogram) {
    if (!float_builder_) {
      SetBuilder(n_trees, &float_builder_, dmat);
    }
    CallBuilderUpdate(float_builder_, gpair, dmat, trees);
  } else {
    if (!double_builder_) {
      SetBuilder(n_trees, &double_builder_, dmat);
    }
    CallBuilderUpdate(double_builder_, gpair, dmat, trees);
  }

  param_.learning_rate = lr;

  p_last_dmat_ = dmat;
}

bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
                                              VectorView<float> out_preds) {
  if (hist_maker_param_.single_precision_histogram && float_builder_) {
    return float_builder_->UpdatePredictionCache(data, out_preds);
  } else if (double_builder_) {
    return double_builder_->UpdatePredictionCache(data, out_preds);
  } else {
    return false;
  }
}
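// A note on the "subtraction trick" used by the synchronizers below: per-bin
// histograms are additive, so for any split hist(parent) == hist(left) +
// hist(right). Once the parent and one child are known, the sibling follows
// from a cheap per-bin subtraction:
//
//   hist(sibling)[bin] = hist(parent)[bin] - hist(node)[bin]
//
// Illustrative numbers (not from the original sources): if a bin holds a
// gradient sum of 10.0 in the parent and 3.5 in the explicitly built child,
// the sibling's bin is 6.5 without touching any rows.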
template <typename GradientSumT>
void BatchHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT* builder,
                                                         int, int,
                                                         RegTree* p_tree) {
  builder->builder_monitor_.Start("SyncHistograms");
  const size_t nbins = builder->hist_builder_.GetNumBins();
  common::BlockedSpace2d space(builder->nodes_for_explicit_hist_build_.size(),
                               [&](size_t) { return nbins; }, 1024);

  common::ParallelFor2d(space, builder->nthread_, [&](size_t node, common::Range1d r) {
    const auto& entry = builder->nodes_for_explicit_hist_build_[node];
    auto this_hist = builder->hist_[entry.nid];
    // Merging histograms from each thread into one
    builder->hist_buffer_.ReduceHist(node, r.begin(), r.end());

    if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
      const size_t parent_id = (*p_tree)[entry.nid].Parent();
      auto parent_hist = builder->hist_[parent_id];
      auto sibling_hist = builder->hist_[entry.sibling_nid];
      SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
    }
  });
  builder->builder_monitor_.Stop("SyncHistograms");
}

template <typename GradientSumT>
void DistributedHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT* builder,
                                                               int starting_index,
                                                               int sync_count,
                                                               RegTree* p_tree) {
  builder->builder_monitor_.Start("SyncHistograms");
  const size_t nbins = builder->hist_builder_.GetNumBins();
  common::BlockedSpace2d space(builder->nodes_for_explicit_hist_build_.size(),
                               [&](size_t) { return nbins; }, 1024);
  common::ParallelFor2d(space, builder->nthread_, [&](size_t node, common::Range1d r) {
    const auto& entry = builder->nodes_for_explicit_hist_build_[node];
    auto this_hist = builder->hist_[entry.nid];
    // Merging histograms from each thread into one
    builder->hist_buffer_.ReduceHist(node, r.begin(), r.end());
    // Store possible parent node
    auto this_local = builder->hist_local_worker_[entry.nid];
    CopyHist(this_local, this_hist, r.begin(), r.end());

    if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
      const size_t parent_id = (*p_tree)[entry.nid].Parent();
      auto parent_hist = builder->hist_local_worker_[parent_id];
      auto sibling_hist = builder->hist_[entry.sibling_nid];
      SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
      // Store possible parent node
      auto sibling_local = builder->hist_local_worker_[entry.sibling_nid];
      CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
    }
  });

  builder->builder_monitor_.Start("SyncHistogramsAllreduce");
  builder->histred_.Allreduce(builder->hist_[starting_index].data(),
                              builder->hist_builder_.GetNumBins() * sync_count);
  builder->builder_monitor_.Stop("SyncHistogramsAllreduce");

  ParallelSubtractionHist(builder, space, builder->nodes_for_explicit_hist_build_, p_tree);

  common::BlockedSpace2d space2(builder->nodes_for_subtraction_trick_.size(),
                                [&](size_t) { return nbins; }, 1024);
  ParallelSubtractionHist(builder, space2, builder->nodes_for_subtraction_trick_, p_tree);
  builder->builder_monitor_.Stop("SyncHistograms");
}

template <typename GradientSumT>
void DistributedHistSynchronizer<GradientSumT>::ParallelSubtractionHist(
    BuilderT* builder,
    const common::BlockedSpace2d& space,
    const std::vector<ExpandEntry>& nodes,
    const RegTree* p_tree) {
  common::ParallelFor2d(space, builder->nthread_, [&](size_t node, common::Range1d r) {
    const auto& entry = nodes[node];
    if (!((*p_tree)[entry.nid].IsLeftChild())) {
      auto this_hist = builder->hist_[entry.nid];

      if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
        auto parent_hist = builder->hist_[(*p_tree)[entry.nid].Parent()];
        auto sibling_hist = builder->hist_[entry.sibling_nid];
        SubtractionHist(this_hist, parent_hist, sibling_hist, r.begin(), r.end());
      }
    }
  });
}
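// The AddHistRows implementations below allocate histogram storage for the
// nodes scheduled at the current level and report two values consumed by the
// synchronizers (a summary inferred from the call sites above):
//   *starting_index - smallest allocated nid, i.e. where the contiguous block
//                     of histograms taking part in the allreduce begins;
//   *sync_count     - how many consecutive histograms are allreduced
//                     (explicit-build nodes in batch mode, left children in
//                     distributed mode).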
template <typename GradientSumT>
void BatchHistRowsAdder<GradientSumT>::AddHistRows(BuilderT* builder,
                                                   int* starting_index,
                                                   int* sync_count, RegTree*) {
  builder->builder_monitor_.Start("AddHistRows");
  for (auto const& entry : builder->nodes_for_explicit_hist_build_) {
    int nid = entry.nid;
    builder->hist_.AddHistRow(nid);
    (*starting_index) = std::min(nid, (*starting_index));
  }
  (*sync_count) = builder->nodes_for_explicit_hist_build_.size();

  for (auto const& node : builder->nodes_for_subtraction_trick_) {
    builder->hist_.AddHistRow(node.nid);
  }
  builder->hist_.AllocateAllData();
  builder->builder_monitor_.Stop("AddHistRows");
}

template <typename GradientSumT>
void DistributedHistRowsAdder<GradientSumT>::AddHistRows(BuilderT* builder,
                                                         int* starting_index,
                                                         int* sync_count,
                                                         RegTree* p_tree) {
  builder->builder_monitor_.Start("AddHistRows");
  const size_t explicit_size = builder->nodes_for_explicit_hist_build_.size();
  const size_t subtraction_size = builder->nodes_for_subtraction_trick_.size();
  std::vector<int> merged_node_ids(explicit_size + subtraction_size);
  for (size_t i = 0; i < explicit_size; ++i) {
    merged_node_ids[i] = builder->nodes_for_explicit_hist_build_[i].nid;
  }
  for (size_t i = 0; i < subtraction_size; ++i) {
    merged_node_ids[explicit_size + i] = builder->nodes_for_subtraction_trick_[i].nid;
  }
  std::sort(merged_node_ids.begin(), merged_node_ids.end());
  int n_left = 0;
  for (auto const& nid : merged_node_ids) {
    if ((*p_tree)[nid].IsLeftChild()) {
      builder->hist_.AddHistRow(nid);
      (*starting_index) = std::min(nid, (*starting_index));
      n_left++;
      builder->hist_local_worker_.AddHistRow(nid);
    }
  }
  for (auto const& nid : merged_node_ids) {
    if (!((*p_tree)[nid].IsLeftChild())) {
      builder->hist_.AddHistRow(nid);
      builder->hist_local_worker_.AddHistRow(nid);
    }
  }
  builder->hist_.AllocateAllData();
  builder->hist_local_worker_.AllocateAllData();
  (*sync_count) = std::max(1, n_left);
  builder->builder_monitor_.Stop("AddHistRows");
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::SetHistSynchronizer(
    HistSynchronizer<GradientSumT>* sync) {
  hist_synchronizer_.reset(sync);
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::SetHistRowsAdder(
    HistRowsAdder<GradientSumT>* adder) {
  hist_rows_adder_.reset(adder);
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::BuildHistogramsLossGuide(
    ExpandEntry entry, const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb,
    RegTree* p_tree, const std::vector<GradientPair>& gpair_h) {
  nodes_for_explicit_hist_build_.clear();
  nodes_for_subtraction_trick_.clear();
  nodes_for_explicit_hist_build_.push_back(entry);

  if (entry.sibling_nid > -1) {
    nodes_for_subtraction_trick_.emplace_back(entry.sibling_nid, entry.nid,
        p_tree->GetDepth(entry.sibling_nid), 0.0f, 0);
  }

  int starting_index = std::numeric_limits<int>::max();
  int sync_count = 0;

  hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, p_tree);
  BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
  hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);
}
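// BuildLocalHistograms below fills one partial histogram per (thread, node)
// pair in hist_buffer_, so threads never contend on a shared histogram; the
// partial results are merged afterwards by ReduceHist inside SyncHistograms.
// Its 2D space is (node x blocks of 256 rows), which lets ParallelFor2d
// balance work across nodes of very different sizes.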
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::BuildLocalHistograms(
    const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb, RegTree* p_tree,
    const std::vector<GradientPair>& gpair_h) {
  builder_monitor_.Start("BuildLocalHistograms");

  const size_t n_nodes = nodes_for_explicit_hist_build_.size();

  // create space of size (# rows in each node)
  common::BlockedSpace2d space(n_nodes, [&](size_t node) {
    const int32_t nid = nodes_for_explicit_hist_build_[node].nid;
    return row_set_collection_[nid].Size();
  }, 256);

  std::vector<GHistRowT> target_hists(n_nodes);
  for (size_t i = 0; i < n_nodes; ++i) {
    const int32_t nid = nodes_for_explicit_hist_build_[i].nid;
    target_hists[i] = hist_[nid];
  }

  hist_buffer_.Reset(this->nthread_, n_nodes, space, target_hists);

  // Parallel processing by nodes and data in each node
  common::ParallelFor2d(space, this->nthread_, [&](size_t nid_in_set, common::Range1d r) {
    const auto tid = static_cast<unsigned>(omp_get_thread_num());
    const int32_t nid = nodes_for_explicit_hist_build_[nid_in_set].nid;

    auto start_of_row_set = row_set_collection_[nid].begin;
    auto rid_set = RowSetCollection::Elem(start_of_row_set + r.begin(),
                                          start_of_row_set + r.end(),
                                          nid);
    BuildHist(gpair_h, rid_set, gmat, gmatb, hist_buffer_.GetInitializedHist(tid, nid_in_set));
  });

  builder_monitor_.Stop("BuildLocalHistograms");
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::BuildNodeStats(
    const GHistIndexMatrix& gmat, DMatrix* p_fmat, RegTree* p_tree,
    const std::vector<GradientPair>& gpair_h) {
  builder_monitor_.Start("BuildNodeStats");
  for (auto const& entry : qexpand_depth_wise_) {
    int nid = entry.nid;
    this->InitNewNode(nid, gmat, gpair_h, *p_fmat, *p_tree);
    // add constraints
    if (!(*p_tree)[nid].IsLeftChild() && !(*p_tree)[nid].IsRoot()) {
      // it's a right child
      auto parent_id = (*p_tree)[nid].Parent();
      auto left_sibling_id = (*p_tree)[parent_id].LeftChild();
      auto parent_split_feature_id = snode_[parent_id].best.SplitIndex();
      tree_evaluator_.AddSplit(parent_id, left_sibling_id, nid, parent_split_feature_id,
                               snode_[left_sibling_id].weight, snode_[nid].weight);
      interaction_constraints_.Split(parent_id, parent_split_feature_id,
                                     left_sibling_id, nid);
    }
  }
  builder_monitor_.Stop("BuildNodeStats");
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToTree(
    const GHistIndexMatrix& gmat, RegTree* p_tree, int* num_leaves, int depth,
    unsigned* timestamp, std::vector<ExpandEntry>* nodes_for_apply_split,
    std::vector<ExpandEntry>* temp_qexpand_depth) {
  auto evaluator = tree_evaluator_.GetEvaluator();
  for (auto const& entry : qexpand_depth_wise_) {
    int nid = entry.nid;

    if (snode_[nid].best.loss_chg < kRtEps ||
        (param_.max_depth > 0 && depth == param_.max_depth) ||
        (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) {
      (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
    } else {
      nodes_for_apply_split->push_back(entry);

      NodeEntry& e = snode_[nid];
      bst_float left_leaf_weight =
          evaluator.CalcWeight(nid, param_, GradStats{e.best.left_sum}) * param_.learning_rate;
      bst_float right_leaf_weight =
          evaluator.CalcWeight(nid, param_, GradStats{e.best.right_sum}) * param_.learning_rate;
      p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
                         e.best.DefaultLeft(), e.weight, left_leaf_weight,
                         right_leaf_weight, e.best.loss_chg, e.stats.GetHess(),
                         e.best.left_sum.GetHess(), e.best.right_sum.GetHess());

      int left_id = (*p_tree)[nid].LeftChild();
      int right_id = (*p_tree)[nid].RightChild();
      temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id,
                                    p_tree->GetDepth(left_id), 0.0, (*timestamp)++));
      temp_qexpand_depth->push_back(ExpandEntry(right_id, left_id,
                                    p_tree->GetDepth(right_id), 0.0, (*timestamp)++));
      // - 1 parent + 2 new children
      (*num_leaves)++;
    }
  }
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::EvaluateAndApplySplits(
    const GHistIndexMatrix& gmat, const ColumnMatrix& column_matrix, RegTree* p_tree,
    int* num_leaves, int depth, unsigned* timestamp,
    std::vector<ExpandEntry>* temp_qexpand_depth) {
  EvaluateSplits(qexpand_depth_wise_, gmat, hist_, *p_tree);

  std::vector<ExpandEntry> nodes_for_apply_split;
  AddSplitsToTree(gmat, p_tree, num_leaves, depth, timestamp,
                  &nodes_for_apply_split, temp_qexpand_depth);
  ApplySplit(nodes_for_apply_split, gmat, column_matrix, hist_, p_tree);
}
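// Worked example for SplitSiblings below (illustrative numbers): a node with
// 1000 rows splits into children holding 300 and 700 rows. The 300-row child
// becomes a "small" sibling and gets an explicitly built histogram; the
// 700-row child becomes a "big" sibling whose histogram is later derived as
// hist(parent) - hist(small sibling), saving a pass over 700 rows.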
// Split nodes to 2 sets depending on amount of rows in each node
// Histograms for small nodes will be built explicitly
// Histograms for big nodes will be built by 'Subtraction Trick'
// Exception: in distributed setting, we always build the histogram for the left child node
// and use 'Subtraction Trick' to build the histogram for the right child node.
// This ensures that the workers operate on the same set of tree nodes.
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
    const std::vector<ExpandEntry>& nodes, std::vector<ExpandEntry>* small_siblings,
    std::vector<ExpandEntry>* big_siblings, RegTree* p_tree) {
  builder_monitor_.Start("SplitSiblings");
  for (auto const& entry : nodes) {
    int nid = entry.nid;
    RegTree::Node& node = (*p_tree)[nid];
    if (node.IsRoot()) {
      small_siblings->push_back(entry);
    } else {
      const int32_t left_id = (*p_tree)[node.Parent()].LeftChild();
      const int32_t right_id = (*p_tree)[node.Parent()].RightChild();

      if (nid == left_id && row_set_collection_[left_id ].Size() <
                            row_set_collection_[right_id].Size()) {
        small_siblings->push_back(entry);
      } else if (nid == right_id && row_set_collection_[right_id].Size() <=
                                    row_set_collection_[left_id ].Size()) {
        small_siblings->push_back(entry);
      } else {
        big_siblings->push_back(entry);
      }
    }
  }
  builder_monitor_.Stop("SplitSiblings");
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::ExpandWithDepthWise(
    const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb,
    const ColumnMatrix& column_matrix, DMatrix* p_fmat, RegTree* p_tree,
    const std::vector<GradientPair>& gpair_h) {
  unsigned timestamp = 0;
  int num_leaves = 0;

  // in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway
  qexpand_depth_wise_.emplace_back(ExpandEntry(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
      p_tree->GetDepth(ExpandEntry::kRootNid), 0.0, timestamp++));
  ++num_leaves;
  for (int depth = 0; depth < param_.max_depth + 1; depth++) {
    int starting_index = std::numeric_limits<int>::max();
    int sync_count = 0;
    std::vector<ExpandEntry> temp_qexpand_depth;
    SplitSiblings(qexpand_depth_wise_, &nodes_for_explicit_hist_build_,
                  &nodes_for_subtraction_trick_, p_tree);
    hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, p_tree);
    BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
    hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);
    BuildNodeStats(gmat, p_fmat, p_tree, gpair_h);

    EvaluateAndApplySplits(gmat, column_matrix, p_tree, &num_leaves, depth, &timestamp,
                           &temp_qexpand_depth);

    // clean up
    qexpand_depth_wise_.clear();
    nodes_for_subtraction_trick_.clear();
    nodes_for_explicit_hist_build_.clear();
    if (temp_qexpand_depth.empty()) {
      break;
    } else {
      qexpand_depth_wise_ = temp_qexpand_depth;
      temp_qexpand_depth.clear();
    }
  }
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::ExpandWithLossGuide(
    const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb,
    const ColumnMatrix& column_matrix, DMatrix* p_fmat, RegTree* p_tree,
    const std::vector<GradientPair>& gpair_h) {
  builder_monitor_.Start("ExpandWithLossGuide");
  unsigned timestamp = 0;
  int num_leaves = 0;

  ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
                   p_tree->GetDepth(0), 0.0f, timestamp++);
  BuildHistogramsLossGuide(node, gmat, gmatb, p_tree, gpair_h);

  this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair_h, *p_fmat, *p_tree);

  this->EvaluateSplits({node}, gmat, hist_, *p_tree);
  node.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg;

  qexpand_loss_guided_->push(node);
  ++num_leaves;
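  // The loop below repeatedly pops the most promising leaf from
  // qexpand_loss_guided_ (a priority queue using the LossGuide comparator set
  // up in InitData; candidates carry their best loss_chg) and either finalizes
  // it as a leaf or applies its best split and enqueues both children.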
  while (!qexpand_loss_guided_->empty()) {
    const ExpandEntry candidate = qexpand_loss_guided_->top();
    const int nid = candidate.nid;
    qexpand_loss_guided_->pop();
    if (candidate.IsValid(param_, num_leaves)) {
      (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
    } else {
      auto evaluator = tree_evaluator_.GetEvaluator();
      NodeEntry& e = snode_[nid];
      bst_float left_leaf_weight =
          evaluator.CalcWeight(nid, param_, GradStats{e.best.left_sum}) * param_.learning_rate;
      bst_float right_leaf_weight =
          evaluator.CalcWeight(nid, param_, GradStats{e.best.right_sum}) * param_.learning_rate;
      p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
                         e.best.DefaultLeft(), e.weight, left_leaf_weight,
                         right_leaf_weight, e.best.loss_chg, e.stats.GetHess(),
                         e.best.left_sum.GetHess(), e.best.right_sum.GetHess());

      this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree);

      const int cleft = (*p_tree)[nid].LeftChild();
      const int cright = (*p_tree)[nid].RightChild();

      ExpandEntry left_node(cleft, cright, p_tree->GetDepth(cleft), 0.0f, timestamp++);
      ExpandEntry right_node(cright, cleft, p_tree->GetDepth(cright), 0.0f, timestamp++);

      if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
        BuildHistogramsLossGuide(left_node, gmat, gmatb, p_tree, gpair_h);
      } else {
        BuildHistogramsLossGuide(right_node, gmat, gmatb, p_tree, gpair_h);
      }

      this->InitNewNode(cleft, gmat, gpair_h, *p_fmat, *p_tree);
      this->InitNewNode(cright, gmat, gpair_h, *p_fmat, *p_tree);
      bst_uint featureid = snode_[nid].best.SplitIndex();
      tree_evaluator_.AddSplit(nid, cleft, cright, featureid,
                               snode_[cleft].weight, snode_[cright].weight);
      interaction_constraints_.Split(nid, featureid, cleft, cright);

      this->EvaluateSplits({left_node, right_node}, gmat, hist_, *p_tree);
      left_node.loss_chg = snode_[cleft].best.loss_chg;
      right_node.loss_chg = snode_[cright].best.loss_chg;

      qexpand_loss_guided_->push(left_node);
      qexpand_loss_guided_->push(right_node);

      ++num_leaves;  // give two and take one, as the parent is no longer a leaf
    }
  }
  builder_monitor_.Stop("ExpandWithLossGuide");
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::Update(
    const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb,
    const ColumnMatrix& column_matrix, HostDeviceVector<GradientPair>* gpair,
    DMatrix* p_fmat, RegTree* p_tree) {
  builder_monitor_.Start("Update");

  std::vector<GradientPair>* gpair_ptr = &(gpair->HostVector());
  // when 'num_parallel_trees != 1' we must not modify the initial gpair, so work on a copy
  if (GetNumberOfTrees() != 1) {
    gpair_local_.resize(gpair_ptr->size());
    gpair_local_ = *gpair_ptr;
    gpair_ptr = &gpair_local_;
  }

  tree_evaluator_ = TreeEvaluator(param_, p_fmat->Info().num_col_, GenericParameter::kCpuId);
  interaction_constraints_.Reset();
  p_last_fmat_mutable_ = p_fmat;

  this->InitData(gmat, *p_fmat, *p_tree, gpair_ptr);

  if (param_.grow_policy == TrainParam::kLossGuide) {
    ExpandWithLossGuide(gmat, gmatb, column_matrix, p_fmat, p_tree, *gpair_ptr);
  } else {
    ExpandWithDepthWise(gmat, gmatb, column_matrix, p_fmat, p_tree, *gpair_ptr);
  }

  for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
    p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg;
    p_tree->Stat(nid).base_weight = snode_[nid].weight;
    p_tree->Stat(nid).sum_hess = static_cast<float>(snode_[nid].stats.GetHess());
  }
  pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});

  builder_monitor_.Stop("Update");
}
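// UpdatePredictionCache lets the gradient booster skip a full prediction pass
// over the training data after this updater has run: every row already sits
// in exactly one leaf's row set, so the new tree's contribution reduces to
// "add that leaf's value to the row's cached prediction", which is what the
// parallel loop below does.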
template <typename GradientSumT>
bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
    const DMatrix* data, VectorView<float> out_preds) {
  // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
  // conjunction with Update().
  if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
      p_last_fmat_ != p_last_fmat_mutable_) {
    return false;
  }
  builder_monitor_.Start("UpdatePredictionCache");
  CHECK_GT(out_preds.Size(), 0U);

  size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin();
  common::BlockedSpace2d space(n_nodes, [&](size_t node) {
    return row_set_collection_[node].Size();
  }, 1024);
  CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
  common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
    const RowSetCollection::Elem rowset = row_set_collection_[node];
    if (rowset.begin != nullptr && rowset.end != nullptr) {
      int nid = rowset.node_id;
      bst_float leaf_value;
      // if a node is marked as deleted by the pruner, traverse upward to locate
      // a non-deleted leaf.
      if ((*p_last_tree_)[nid].IsDeleted()) {
        while ((*p_last_tree_)[nid].IsDeleted()) {
          nid = (*p_last_tree_)[nid].Parent();
        }
        CHECK((*p_last_tree_)[nid].IsLeaf());
      }
      leaf_value = (*p_last_tree_)[nid].LeafValue();

      for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
        out_preds[*it] += leaf_value;
      }
    }
  });

  builder_monitor_.Stop("UpdatePredictionCache");
  return true;
}
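// InitSampling implements uniform row subsampling (param_.subsample): a row
// keeps its gradient pair with probability `subsample` and is otherwise
// zeroed out via GradientPair(0), which removes it from histogram and split
// accumulation. In the parallel branch each thread handles a contiguous chunk
// [ibegin, iend), and RandomReplace derives the coin flips deterministically
// from initial_seed and the row positions (see ../common/random.h).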
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix& fmat,
                                                            std::vector<GradientPair>* gpair,
                                                            std::vector<size_t>* row_indices) {
  const auto& info = fmat.Info();
  auto& rnd = common::GlobalRandom();
  std::vector<GradientPair>& gpair_ref = *gpair;

#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
  std::bernoulli_distribution coin_flip(param_.subsample);
  size_t used = 0, unused = 0;
  for (size_t i = 0; i < info.num_row_; ++i) {
    if (!(gpair_ref[i].GetHess() >= 0.0f && coin_flip(rnd)) || gpair_ref[i].GetGrad() == 0.0f) {
      gpair_ref[i] = GradientPair(0);
    }
  }
#else
  const size_t nthread = this->nthread_;
  uint64_t initial_seed = rnd();

  const size_t discard_size = info.num_row_ / nthread;
  std::bernoulli_distribution coin_flip(param_.subsample);

  dmlc::OMPException exc;
  #pragma omp parallel num_threads(nthread)
  {
    exc.Run([&]() {
      const size_t tid = omp_get_thread_num();
      const size_t ibegin = tid * discard_size;
      const size_t iend = (tid == (nthread - 1)) ? info.num_row_ : ibegin + discard_size;
      RandomReplace::MakeIf([&](size_t i, RandomReplace::EngineT& eng) {
        return !(gpair_ref[i].GetHess() >= 0.0f && coin_flip(eng));
      }, GradientPair(0), initial_seed, ibegin, iend, &gpair_ref);
    });
  }
  exc.Rethrow();
#endif  // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
}

template <typename GradientSumT>
size_t QuantileHistMaker::Builder<GradientSumT>::GetNumberOfTrees() {
  return n_trees_;
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix& gmat,
                                                        const DMatrix& fmat,
                                                        const RegTree& tree,
                                                        std::vector<GradientPair>* gpair) {
  CHECK((param_.max_depth > 0 || param_.max_leaves > 0))
      << "max_depth and max_leaves cannot both be 0 (unlimited); "
      << "at least one should be a positive quantity.";
  if (param_.grow_policy == TrainParam::kDepthWise) {
    CHECK(param_.max_depth > 0) << "max_depth cannot be 0 (unlimited) "
                                << "when grow_policy is depthwise.";
  }
  builder_monitor_.Start("InitData");
  const auto& info = fmat.Info();

  {
    // initialize the row set
    row_set_collection_.Clear();
    // initialize histogram collection
    uint32_t nbins = gmat.cut.Ptrs().back();
    hist_.Init(nbins);
    hist_local_worker_.Init(nbins);
    hist_buffer_.Init(nbins);

    // initialize histogram builder
    dmlc::OMPException exc;
    #pragma omp parallel
    {
      exc.Run([&]() {
        this->nthread_ = omp_get_num_threads();
      });
    }
    exc.Rethrow();
    hist_builder_ = GHistBuilder<GradientSumT>(this->nthread_, nbins);

    std::vector<size_t>& row_indices = *row_set_collection_.Data();
    row_indices.resize(info.num_row_);
    size_t* p_row_indices = row_indices.data();
    // mark subsample and build list of member rows
    if (param_.subsample < 1.0f) {
      CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
          << "Only uniform sampling is supported; "
          << "gradient-based sampling is only supported by GPU Hist.";
      builder_monitor_.Start("InitSampling");
      InitSampling(fmat, gpair, &row_indices);
      builder_monitor_.Stop("InitSampling");
      CHECK_EQ(row_indices.size(), info.num_row_);
      // We should check that the partitioning was done correctly
      // and each row of the dataset fell into exactly one of the categories
    }
    MemStackAllocator<bool, 128> buff(this->nthread_);
    bool* p_buff = buff.Get();
    std::fill(p_buff, p_buff + this->nthread_, false);

    const size_t block_size = info.num_row_ / this->nthread_ +
                              !!(info.num_row_ % this->nthread_);

    #pragma omp parallel num_threads(this->nthread_)
    {
      exc.Run([&]() {
        const size_t tid = omp_get_thread_num();
        const size_t ibegin = tid * block_size;
        const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
                                     static_cast<size_t>(info.num_row_));
        for (size_t i = ibegin; i < iend; ++i) {
          if ((*gpair)[i].GetHess() < 0.0f) {
            p_buff[tid] = true;
            break;
          }
        }
      });
    }
    exc.Rethrow();

    bool has_neg_hess = false;
    for (int32_t tid = 0; tid < this->nthread_; ++tid) {
      if (p_buff[tid]) {
        has_neg_hess = true;
      }
    }

    if (has_neg_hess) {
      size_t j = 0;
      for (size_t i = 0; i < info.num_row_; ++i) {
        if ((*gpair)[i].GetHess() >= 0.0f) {
          p_row_indices[j++] = i;
        }
      }
      row_indices.resize(j);
    } else {
      #pragma omp parallel num_threads(this->nthread_)
      {
        exc.Run([&]() {
          const size_t tid = omp_get_thread_num();
          const size_t ibegin = tid * block_size;
          const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
                                       static_cast<size_t>(info.num_row_));
          for (size_t i = ibegin; i < iend; ++i) {
            p_row_indices[i] = i;
          }
        });
      }
      exc.Rethrow();
    }
  }

  row_set_collection_.Init();
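  // Layout detection below, with illustrative numbers: for nrow=100, ncol=10,
  // a dataset with nnz == 1000 == nrow*ncol has no missing entries at all
  // (kDenseDataZeroBased). If feature 0 has zero bins and nnz == 900 ==
  // nrow*(ncol-1), the data came from a one-based format whose column 0 is
  // empty (kDenseDataOneBased). Everything else is treated as truly sparse.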
  {
    /* determine layout of data */
    const size_t nrow = info.num_row_;
    const size_t ncol = info.num_col_;
    const size_t nnz = info.num_nonzero_;
    // number of discrete bins for feature 0
    const uint32_t nbins_f0 = gmat.cut.Ptrs()[1] - gmat.cut.Ptrs()[0];
    if (nrow * ncol == nnz) {
      // dense data with zero-based indexing
      data_layout_ = DataLayout::kDenseDataZeroBased;
    } else if (nbins_f0 == 0 && nrow * (ncol - 1) == nnz) {
      // dense data with one-based indexing
      data_layout_ = DataLayout::kDenseDataOneBased;
    } else {
      // sparse data
      data_layout_ = DataLayout::kSparseData;
    }
  }
  // store a pointer to the tree
  p_last_tree_ = &tree;
  if (data_layout_ == DataLayout::kDenseDataOneBased) {
    column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
                         param_.colsample_bynode, param_.colsample_bylevel,
                         param_.colsample_bytree, true);
  } else {
    column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
                         param_.colsample_bynode, param_.colsample_bylevel,
                         param_.colsample_bytree, false);
  }
  if (data_layout_ == DataLayout::kDenseDataZeroBased ||
      data_layout_ == DataLayout::kDenseDataOneBased) {
    /* specialized code for dense data:
       choose the column that has the smallest positive number of discrete bins.
       For dense data (with no missing values), the sum of the gradient histogram
       is equal to snode[nid] */
    const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
    const auto nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
    uint32_t min_nbins_per_feature = 0;
    for (bst_uint i = 0; i < nfeature; ++i) {
      const uint32_t nbins = row_ptr[i + 1] - row_ptr[i];
      if (nbins > 0) {
        if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) {
          min_nbins_per_feature = nbins;
          fid_least_bins_ = i;
        }
      }
    }
    CHECK_GT(min_nbins_per_feature, 0U);
  }
  {
    snode_.reserve(256);
    snode_.clear();
  }
  {
    if (param_.grow_policy == TrainParam::kLossGuide) {
      qexpand_loss_guided_.reset(new ExpandQueue(LossGuide));
    } else {
      qexpand_depth_wise_.clear();
    }
  }
  builder_monitor_.Stop("InitData");
}

// if sum of statistics for non-missing values in the node
// is equal to sum of statistics for all values:
// then - there are no missing values
// else - there are missing values
template <typename GradientSumT>
bool QuantileHistMaker::Builder<GradientSumT>::SplitContainsMissingValues(
    const GradStats e, const NodeEntry& snode) {
  if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) {
    return false;
  } else {
    return true;
  }
}
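// EvaluateSplits below enumerates each candidate feature at most twice: a
// forward pass (EnumerateSplit<+1>) that implicitly sends missing values to
// the right and, only when SplitContainsMissingValues detects missing values
// for that feature, a backward pass (EnumerateSplit<-1>) that sends them to
// the left. The better direction becomes the split's default direction.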
// nodes_set - set of nodes to be processed in parallel
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::EvaluateSplits(
    const std::vector<ExpandEntry>& nodes_set, const GHistIndexMatrix& gmat,
    const HistCollection<GradientSumT>& hist, const RegTree& tree) {
  builder_monitor_.Start("EvaluateSplits");

  const size_t n_nodes_in_set = nodes_set.size();
  const size_t nthread = std::max(1, this->nthread_);

  using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>;
  std::vector<FeatureSetType> features_sets(n_nodes_in_set);
  best_split_tloc_.resize(nthread * n_nodes_in_set);

  // Generate feature set for each tree node
  for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
    const int32_t nid = nodes_set[nid_in_set].nid;
    features_sets[nid_in_set] = column_sampler_.GetFeatureSet(tree.GetDepth(nid));

    for (unsigned tid = 0; tid < nthread; ++tid) {
      best_split_tloc_[nthread*nid_in_set + tid] = snode_[nid].best;
    }
  }

  // Create 2D space (# of nodes to process x # of features to process)
  // to process them in parallel
  const size_t grain_size = std::max<size_t>(1, features_sets[0]->Size() / nthread);
  common::BlockedSpace2d space(n_nodes_in_set, [&](size_t nid_in_set) {
    return features_sets[nid_in_set]->Size();
  }, grain_size);

  auto evaluator = tree_evaluator_.GetEvaluator();
  // Start parallel enumeration for all tree nodes in the set and all features
  common::ParallelFor2d(space, this->nthread_, [&](size_t nid_in_set, common::Range1d r) {
    const int32_t nid = nodes_set[nid_in_set].nid;
    const auto tid = static_cast<unsigned>(omp_get_thread_num());
    GHistRowT node_hist = hist[nid];

    for (auto idx_in_feature_set = r.begin(); idx_in_feature_set < r.end();
         ++idx_in_feature_set) {
      const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx_in_feature_set];
      if (interaction_constraints_.Query(nid, fid)) {
        auto grad_stats = this->EnumerateSplit<+1>(
            gmat, node_hist, snode_[nid],
            &best_split_tloc_[nthread * nid_in_set + tid], fid, nid, evaluator);
        if (SplitContainsMissingValues(grad_stats, snode_[nid])) {
          this->EnumerateSplit<-1>(
              gmat, node_hist, snode_[nid],
              &best_split_tloc_[nthread * nid_in_set + tid], fid, nid, evaluator);
        }
      }
    }
  });

  // Find Best Split across threads for each node in nodes set
  for (unsigned nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
    const int32_t nid = nodes_set[nid_in_set].nid;
    for (unsigned tid = 0; tid < nthread; ++tid) {
      snode_[nid].best.Update(best_split_tloc_[nthread*nid_in_set + tid]);
    }
  }

  builder_monitor_.Stop("EvaluateSplits");
}

// split row indexes (rid_span) to 2 parts (left_part, right_part) depending
// on comparison of indexes values (idx_span) and split point (split_cond)
// Handle dense columns
// Analog of std::stable_partition, but not in-place
template <bool default_left, bool any_missing, typename BinIdxType>
inline std::pair<size_t, size_t> PartitionDenseKernel(
    const common::DenseColumn<BinIdxType>& column, common::Span<const size_t> rid_span,
    const int32_t split_cond, common::Span<size_t> left_part,
    common::Span<size_t> right_part) {
  const int32_t offset = column.GetBaseIdx();
  const BinIdxType* idx = column.GetFeatureBinIdxPtr().data();
  size_t* p_left_part = left_part.data();
  size_t* p_right_part = right_part.data();
  size_t nleft_elems = 0;
  size_t nright_elems = 0;

  if (any_missing) {
    for (auto rid : rid_span) {
      if (column.IsMissing(rid)) {
        if (default_left) {
          p_left_part[nleft_elems++] = rid;
        } else {
          p_right_part[nright_elems++] = rid;
        }
      } else {
        if ((static_cast<int32_t>(idx[rid]) + offset) <= split_cond) {
          p_left_part[nleft_elems++] = rid;
        } else {
          p_right_part[nright_elems++] = rid;
        }
      }
    }
  } else {
    for (auto rid : rid_span) {
      if ((static_cast<int32_t>(idx[rid]) + offset) <= split_cond) {
        p_left_part[nleft_elems++] = rid;
      } else {
        p_right_part[nright_elems++] = rid;
      }
    }
  }

  return {nleft_elems, nright_elems};
}
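// Decision rule of the dense kernel above, with illustrative numbers: for
// split_cond = 42, a row whose global bin index (idx[rid] + offset) is <= 42
// goes to left_part, 43 or above goes to right_part, and a missing value
// follows default_left. Both output buffers are compacted independently,
// which is why the kernel is a non-in-place analog of std::stable_partition.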
// Split row indexes (rid_span) to 2 parts (left_part, right_part) depending
// on comparison of indexes values (idx_span) and split point (split_cond).
// Handle sparse columns
template <bool default_left, typename BinIdxType>
inline std::pair<size_t, size_t> PartitionSparseKernel(
    const common::SparseColumn<BinIdxType>& column, common::Span<const size_t> rid_span,
    const int32_t split_cond, common::Span<size_t> left_part,
    common::Span<size_t> right_part) {
  size_t* p_left_part = left_part.data();
  size_t* p_right_part = right_part.data();

  size_t nleft_elems = 0;
  size_t nright_elems = 0;
  const size_t* row_data = column.GetRowData();
  const size_t column_size = column.Size();
  if (rid_span.size()) {  // ensure that rid_span is a nonempty range
    // search first nonzero row with index >= rid_span.front()
    const size_t* p = std::lower_bound(row_data, row_data + column_size, rid_span.front());

    if (p != row_data + column_size && *p <= rid_span.back()) {
      size_t cursor = p - row_data;

      for (auto rid : rid_span) {
        while (cursor < column_size && column.GetRowIdx(cursor) < rid &&
               column.GetRowIdx(cursor) <= rid_span.back()) {
          ++cursor;
        }
        if (cursor < column_size && column.GetRowIdx(cursor) == rid) {
          if (static_cast<int32_t>(column.GetGlobalBinIdx(cursor)) <= split_cond) {
            p_left_part[nleft_elems++] = rid;
          } else {
            p_right_part[nright_elems++] = rid;
          }
          ++cursor;
        } else {
          // missing value
          if (default_left) {
            p_left_part[nleft_elems++] = rid;
          } else {
            p_right_part[nright_elems++] = rid;
          }
        }
      }
    } else {  // all rows in rid_span have missing values
      if (default_left) {
        std::copy(rid_span.begin(), rid_span.end(), p_left_part);
        nleft_elems = rid_span.size();
      } else {
        std::copy(rid_span.begin(), rid_span.end(), p_right_part);
        nright_elems = rid_span.size();
      }
    }
  }

  return {nleft_elems, nright_elems};
}

template <typename GradientSumT>
template <typename BinIdxType>
void QuantileHistMaker::Builder<GradientSumT>::PartitionKernel(
    const size_t node_in_set, const size_t nid, const common::Range1d range,
    const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree) {
  const size_t* rid = row_set_collection_[nid].begin;

  common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
  common::Span<size_t> left = partition_builder_.GetLeftBuffer(node_in_set,
                                                               range.begin(), range.end());
  common::Span<size_t> right = partition_builder_.GetRightBuffer(node_in_set,
                                                                 range.begin(), range.end());
  const bst_uint fid = tree[nid].SplitIndex();
  const bool default_left = tree[nid].DefaultLeft();
  const auto column_ptr = column_matrix.GetColumn<BinIdxType>(fid);

  std::pair<size_t, size_t> child_nodes_sizes;

  if (column_ptr->GetType() == xgboost::common::kDenseColumn) {
    const common::DenseColumn<BinIdxType>& column =
        static_cast<const common::DenseColumn<BinIdxType>& >(*(column_ptr.get()));
    if (default_left) {
      if (column_matrix.AnyMissing()) {
        child_nodes_sizes = PartitionDenseKernel<true, true>(column, rid_span, split_cond,
                                                             left, right);
      } else {
        child_nodes_sizes = PartitionDenseKernel<true, false>(column, rid_span, split_cond,
                                                              left, right);
      }
    } else {
      if (column_matrix.AnyMissing()) {
        child_nodes_sizes = PartitionDenseKernel<false, true>(column, rid_span, split_cond,
                                                              left, right);
      } else {
        child_nodes_sizes = PartitionDenseKernel<false, false>(column, rid_span, split_cond,
                                                               left, right);
      }
    }
  } else {
    const common::SparseColumn<BinIdxType>& column =
        static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
    if (default_left) {
      child_nodes_sizes = PartitionSparseKernel<true>(column, rid_span, split_cond,
                                                      left, right);
    } else {
      child_nodes_sizes = PartitionSparseKernel<false>(column, rid_span, split_cond,
                                                       left, right);
    }
  }

  const size_t n_left = child_nodes_sizes.first;
  const size_t n_right = child_nodes_sizes.second;

  partition_builder_.SetNLeftElems(node_in_set, range.begin(), range.end(), n_left);
  partition_builder_.SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
}
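// FindSplitConditions below translates a floating-point split value back into
// bin space. Illustrative example: if feature fid owns the cut-point range
// [lower_bound, upper_bound) = [128, 160) and gmat.cut.Values()[137] equals
// split_pt, then split_cond = 137, and the partition kernels compare integer
// bin indices against it instead of re-reading raw feature values.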
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
    const std::vector<ExpandEntry>& nodes, const RegTree& tree, const GHistIndexMatrix& gmat,
    std::vector<int32_t>* split_conditions) {
  const size_t n_nodes = nodes.size();
  split_conditions->resize(n_nodes);

  for (size_t i = 0; i < nodes.size(); ++i) {
    const int32_t nid = nodes[i].nid;
    const bst_uint fid = tree[nid].SplitIndex();
    const bst_float split_pt = tree[nid].SplitCond();
    const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
    const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
    int32_t split_cond = -1;
    // convert floating-point split_pt into corresponding bin_id
    // split_cond = -1 indicates that split_pt is less than all known cut points
    CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
    for (uint32_t bound = lower_bound; bound < upper_bound; ++bound) {
      if (split_pt == gmat.cut.Values()[bound]) {
        split_cond = static_cast<int32_t>(bound);
      }
    }
    (*split_conditions)[i] = split_cond;
  }
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToRowSet(
    const std::vector<ExpandEntry>& nodes, RegTree* p_tree) {
  const size_t n_nodes = nodes.size();
  for (unsigned int i = 0; i < n_nodes; ++i) {
    const int32_t nid = nodes[i].nid;
    const size_t n_left = partition_builder_.GetNLeftElems(i);
    const size_t n_right = partition_builder_.GetNRightElems(i);
    CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
    row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(),
                                 (*p_tree)[nid].RightChild(), n_left, n_right);
  }
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::ApplySplit(
    const std::vector<ExpandEntry> nodes, const GHistIndexMatrix& gmat,
    const ColumnMatrix& column_matrix, const HistCollection<GradientSumT>& hist,
    RegTree* p_tree) {
  builder_monitor_.Start("ApplySplit");
  // 1. Find split condition for each split
  const size_t n_nodes = nodes.size();
  std::vector<int32_t> split_conditions;
  FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
  // 2.1 Create a blocked space of size SUM(samples in each node)
  common::BlockedSpace2d space(n_nodes, [&](size_t node_in_set) {
    int32_t nid = nodes[node_in_set].nid;
    return row_set_collection_[nid].Size();
  }, kPartitionBlockSize);
  // 2.2 Initialize the partition builder
  // allocate buffers for storing intermediate results by each thread
  partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
    const int32_t nid = nodes[node_in_set].nid;
    const size_t size = row_set_collection_[nid].Size();
    const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
    return n_tasks;
  });
  // 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
  // Store results in intermediate buffers from partition_builder_
  common::ParallelFor2d(space, this->nthread_, [&](size_t node_in_set, common::Range1d r) {
    size_t begin = r.begin();
    const int32_t nid = nodes[node_in_set].nid;
    const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
    partition_builder_.AllocateForTask(task_id);
    switch (column_matrix.GetTypeSize()) {
      case common::kUint8BinsTypeSize:
        PartitionKernel<uint8_t>(node_in_set, nid, r,
                                 split_conditions[node_in_set], column_matrix, *p_tree);
        break;
      case common::kUint16BinsTypeSize:
        PartitionKernel<uint16_t>(node_in_set, nid, r,
                                  split_conditions[node_in_set], column_matrix, *p_tree);
        break;
      case common::kUint32BinsTypeSize:
        PartitionKernel<uint32_t>(node_in_set, nid, r,
                                  split_conditions[node_in_set], column_matrix, *p_tree);
        break;
      default:
        CHECK(false);  // no default behavior
    }
  });
  // 3. Compute offsets to copy blocks of row-indexes
  // from partition_builder_ to row_set_collection_
  partition_builder_.CalculateRowOffsets();
  // 4. Copy elements from partition_builder_ to row_set_collection_ back
  // with updated row-indexes for each tree-node
  common::ParallelFor2d(space, this->nthread_, [&](size_t node_in_set, common::Range1d r) {
    const int32_t nid = nodes[node_in_set].nid;
    partition_builder_.MergeToArray(node_in_set, r.begin(),
        const_cast<size_t*>(row_set_collection_[nid].begin));
  });
  // 5. Add info about splits into row_set_collection_
  AddSplitsToRowSet(nodes, p_tree);

  builder_monitor_.Stop("ApplySplit");
}

template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitNewNode(int nid,
                                                           const GHistIndexMatrix& gmat,
                                                           const std::vector<GradientPair>& gpair,
                                                           const DMatrix& fmat,
                                                           const RegTree& tree) {
  builder_monitor_.Start("InitNewNode");
  {
    snode_.resize(tree.param.num_nodes, NodeEntry(param_));
  }

  {
    GHistRowT hist = hist_[nid];
    GradientPairT grad_stat;
    if (tree[nid].IsRoot()) {
      if (data_layout_ == DataLayout::kDenseDataZeroBased ||
          data_layout_ == DataLayout::kDenseDataOneBased) {
        const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
        const uint32_t ibegin = row_ptr[fid_least_bins_];
        const uint32_t iend = row_ptr[fid_least_bins_ + 1];
        auto begin = hist.data();
        for (uint32_t i = ibegin; i < iend; ++i) {
          const GradientPairT et = begin[i];
          grad_stat.Add(et.GetGrad(), et.GetHess());
        }
      } else {
        const RowSetCollection::Elem e = row_set_collection_[nid];
        for (const size_t* it = e.begin; it < e.end; ++it) {
          grad_stat.Add(gpair[*it].GetGrad(), gpair[*it].GetHess());
        }
      }
      histred_.Allreduce(&grad_stat, 1);
      snode_[nid].stats = tree::GradStats(grad_stat.GetGrad(), grad_stat.GetHess());
    } else {
      int parent_id = tree[nid].Parent();
      if (tree[nid].IsLeftChild()) {
        snode_[nid].stats = snode_[parent_id].best.left_sum;
      } else {
        snode_[nid].stats = snode_[parent_id].best.right_sum;
      }
    }
  }

  // calculating the weights
  {
    auto evaluator = tree_evaluator_.GetEvaluator();
    bst_uint parentid = tree[nid].Parent();
    snode_[nid].weight = static_cast<float>(
        evaluator.CalcWeight(parentid, param_, GradStats{snode_[nid].stats}));
    snode_[nid].root_gain = static_cast<float>(
        evaluator.CalcGain(parentid, param_, GradStats{snode_[nid].stats}));
  }
  builder_monitor_.Stop("InitNewNode");
}
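// Gain evaluated by EnumerateSplit below, written out in the usual XGBoost
// form (a simplification: constraint handling and alpha/lambda regularization
// live inside the evaluator's CalcSplitGain):
//
//   loss_chg = Gain(G_L, H_L) + Gain(G_R, H_R) - root_gain,
//   Gain(G, H) ~= G^2 / (H + lambda)
//
// where (G_L, H_L) and (G_R, H_R) are the gradient/hessian sums accumulated
// on either side of the candidate threshold.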
// Enumerate the split values of specific feature.
// Returns the sum of gradients corresponding to the data points that contain a
// non-missing value for the particular feature fid.
template <typename GradientSumT>
template <int32_t d_step>
GradStats QuantileHistMaker::Builder<GradientSumT>::EnumerateSplit(
    const GHistIndexMatrix& gmat, const GHistRowT& hist, const NodeEntry& snode,
    SplitEntry* p_best, bst_uint fid, bst_uint nodeID,
    TreeEvaluator::SplitEvaluator<TrainParam> const& evaluator) const {
  CHECK(d_step == +1 || d_step == -1);

  // aliases
  const std::vector<uint32_t>& cut_ptr = gmat.cut.Ptrs();
  const std::vector<bst_float>& cut_val = gmat.cut.Values();

  // statistics on both sides of split
  GradStats c;
  GradStats e;
  // best split so far
  SplitEntry best;

  // bin boundaries
  CHECK_LE(cut_ptr[fid], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
  CHECK_LE(cut_ptr[fid + 1], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
  // imin: index (offset) of the minimum value for feature fid
  //       need this for backward enumeration
  const auto imin = static_cast<int32_t>(cut_ptr[fid]);
  // ibegin, iend: smallest/largest cut points for feature fid
  // use int to allow for value -1
  int32_t ibegin, iend;
  if (d_step > 0) {
    ibegin = static_cast<int32_t>(cut_ptr[fid]);
    iend = static_cast<int32_t>(cut_ptr[fid + 1]);
  } else {
    ibegin = static_cast<int32_t>(cut_ptr[fid + 1]) - 1;
    iend = static_cast<int32_t>(cut_ptr[fid]) - 1;
  }

  for (int32_t i = ibegin; i != iend; i += d_step) {
    // start working
    // try to find a split
    e.Add(hist[i].GetGrad(), hist[i].GetHess());
    if (e.GetHess() >= param_.min_child_weight) {
      c.SetSubstract(snode.stats, e);
      if (c.GetHess() >= param_.min_child_weight) {
        bst_float loss_chg;
        bst_float split_pt;
        if (d_step > 0) {
          // forward enumeration: split at right bound of each bin
          loss_chg = static_cast<bst_float>(
              evaluator.CalcSplitGain(param_, nodeID, fid, GradStats{e}, GradStats{c}) -
              snode.root_gain);
          split_pt = cut_val[i];
          best.Update(loss_chg, fid, split_pt, d_step == -1, e, c);
        } else {
          // backward enumeration: split at left bound of each bin
          loss_chg = static_cast<bst_float>(
              evaluator.CalcSplitGain(param_, nodeID, fid, GradStats{c}, GradStats{e}) -
              snode.root_gain);
          if (i == imin) {
            // for leftmost bin, left bound is the smallest feature value
            split_pt = gmat.cut.MinValues()[fid];
          } else {
            split_pt = cut_val[i - 1];
          }
          best.Update(loss_chg, fid, split_pt, d_step == -1, c, e);
        }
      }
    }
  }
  p_best->Update(best);

  return e;
}

template struct QuantileHistMaker::Builder<float>;
template struct QuantileHistMaker::Builder<double>;
template void QuantileHistMaker::Builder<float>::PartitionKernel<uint8_t>(
    const size_t node_in_set, const size_t nid, common::Range1d range,
    const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree);
template void QuantileHistMaker::Builder<float>::PartitionKernel<uint16_t>(
    const size_t node_in_set, const size_t nid, common::Range1d range,
    const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree);
template void QuantileHistMaker::Builder<float>::PartitionKernel<uint32_t>(
    const size_t node_in_set, const size_t nid, common::Range1d range,
    const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree);
template void QuantileHistMaker::Builder<double>::PartitionKernel<uint8_t>(
    const size_t node_in_set, const size_t nid, common::Range1d range,
    const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree);
template void QuantileHistMaker::Builder<double>::PartitionKernel<uint16_t>(
    const size_t node_in_set, const size_t nid, common::Range1d range,
    const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree);
template void QuantileHistMaker::Builder<double>::PartitionKernel<uint32_t>(
    const size_t node_in_set, const size_t nid, common::Range1d range,
    const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree);
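// Both registered names below can be selected through the `updater` training
// parameter, e.g. {"updater": "grow_quantile_histmaker"};
// "grow_fast_histmaker" survives only as a deprecated alias that logs a
// warning before constructing the same QuantileHistMaker.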
XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
.describe("(Deprecated, use grow_quantile_histmaker instead.)"
          " Grow tree using quantized histogram.")
.set_body(
    []() {
      LOG(WARNING) << "grow_fast_histmaker is deprecated, "
                   << "use grow_quantile_histmaker instead.";
      return new QuantileHistMaker();
    });

XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
.describe("Grow tree using quantized histogram.")
.set_body(
    []() {
      return new QuantileHistMaker();
    });
}  // namespace tree
}  // namespace xgboost