/*!
 * Copyright 2021 XGBoost contributors
 *
 * \brief Implementation for the approx tree method.
 */
#include "updater_approx.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "../common/random.h"
#include "../data/gradient_index.h"
#include "constraints.h"
#include "driver.h"
#include "hist/evaluate_splits.h"
#include "hist/histogram.h"
#include "hist/param.h"
#include "param.h"
#include "xgboost/base.h"
#include "xgboost/json.h"
#include "xgboost/tree_updater.h"

namespace xgboost {
namespace tree {

DMLC_REGISTRY_FILE_TAG(updater_approx);

template <typename GradientSumT>
class GloablApproxBuilder {
 protected:
  TrainParam param_;
  std::shared_ptr<common::ColumnSampler> col_sampler_;
  HistEvaluator<GradientSumT, CPUExpandEntry> evaluator_;
  HistogramBuilder<GradientSumT, CPUExpandEntry> histogram_builder_;
  GenericParameter const *ctx_;

  std::vector<ApproxRowPartitioner> partitioner_;
  // Pointer to the last updated tree, used for updating the prediction cache.
  RegTree *p_last_tree_{nullptr};
  common::Monitor *monitor_;
  size_t n_batches_{0};
  // Cache for histogram cuts.
  common::HistogramCuts feature_values_;

 public:
  void InitData(DMatrix *p_fmat, common::Span<float> hess) {
    monitor_->Start(__func__);

    n_batches_ = 0;
    int32_t n_total_bins = 0;
    partitioner_.clear();
    // Generating the GHistIndexMatrix is quite slow; is there a way to speed it up?
    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(
             {GenericParameter::kCpuId, param_.max_bin, hess, true})) {
      if (n_total_bins == 0) {
        n_total_bins = page.cut.TotalBins();
        feature_values_ = page.cut;
      } else {
        CHECK_EQ(n_total_bins, page.cut.TotalBins());
      }
      partitioner_.emplace_back(page.Size(), page.base_rowid);
      n_batches_++;
    }

    histogram_builder_.Reset(n_total_bins,
                             BatchParam{GenericParameter::kCpuId, param_.max_bin, hess},
                             ctx_->Threads(), n_batches_, rabit::IsDistributed());
    monitor_->Stop(__func__);
  }

  CPUExpandEntry InitRoot(DMatrix *p_fmat, std::vector<GradientPair> const &gpair,
                          common::Span<float> hess, RegTree *p_tree) {
    monitor_->Start(__func__);
    CPUExpandEntry best;
    best.nid = RegTree::kRoot;
    best.depth = 0;
    GradStats root_sum;
    for (auto const &g : gpair) {
      root_sum.Add(g);
    }
    rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&root_sum), 2);
    std::vector<CPUExpandEntry> nodes{best};
    size_t i = 0;
    auto space = this->ConstructHistSpace(nodes);
    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(
             {GenericParameter::kCpuId, param_.max_bin, hess})) {
      histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
                                   {}, gpair);
      i++;
    }

    auto weight = evaluator_.InitRoot(root_sum);
    p_tree->Stat(RegTree::kRoot).sum_hess = root_sum.GetHess();
    p_tree->Stat(RegTree::kRoot).base_weight = weight;
    (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);

    auto const &histograms = histogram_builder_.Histogram();
    auto ft = p_fmat->Info().feature_types.ConstHostSpan();
    evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &nodes);
    monitor_->Stop(__func__);

    return nodes.front();
  }
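  // Finalize cached predictions for the last updated tree: for every leaf node, recompute
  // the leaf weight from the accumulated node statistics and add learning_rate * weight to
  // the prediction of each row that falls into that leaf's partition.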
  void UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) {
    monitor_->Start(__func__);
    // Caching the prediction seems redundant for the approx tree method, as sketching takes
    // up the majority of training time.
    CHECK_EQ(out_preds.Size(), data->Info().num_row_);
    CHECK(p_last_tree_);

    size_t n_nodes = p_last_tree_->GetNodes().size();

    auto evaluator = evaluator_.Evaluator();
    auto const &tree = *p_last_tree_;
    auto const &snode = evaluator_.Stats();
    for (auto &part : partitioner_) {
      CHECK_EQ(part.Size(), n_nodes);
      common::BlockedSpace2d space(
          part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
      common::ParallelFor2d(space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) {
        if (tree[nidx].IsLeaf()) {
          const auto rowset = part[nidx];
          auto const &stats = snode.at(nidx);
          auto leaf_value =
              evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * param_.learning_rate;
          for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
            out_preds(*it) += leaf_value;
          }
        }
      });
    }
    monitor_->Stop(__func__);
  }

  // Construct a work space for building histograms.  Eventually we should move this
  // function into the histogram builder once the hist tree method supports external memory.
  common::BlockedSpace2d ConstructHistSpace(
      std::vector<CPUExpandEntry> const &nodes_to_build) const {
    std::vector<size_t> partition_size(nodes_to_build.size(), 0);
    for (auto const &partition : partitioner_) {
      size_t k = 0;
      for (auto node : nodes_to_build) {
        auto n_rows_in_node = partition.Partitions()[node.nid].Size();
        partition_size[k] = std::max(partition_size[k], n_rows_in_node);
        k++;
      }
    }
    common::BlockedSpace2d space{
        nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; },
        256};
    return space;
  }
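  // Build histograms for the children of the applied splits.  For each split, only the
  // child with the smaller hessian sum is built directly; its sibling is queued in
  // nodes_to_sub so the histogram builder can derive it via the subtraction trick
  // (parent histogram minus the built child).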
  void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree,
                      std::vector<CPUExpandEntry> const &valid_candidates,
                      std::vector<GradientPair> const &gpair, common::Span<float> hess) {
    monitor_->Start(__func__);
    std::vector<CPUExpandEntry> nodes_to_build;
    std::vector<CPUExpandEntry> nodes_to_sub;

    for (auto const &c : valid_candidates) {
      auto left_nidx = (*p_tree)[c.nid].LeftChild();
      auto right_nidx = (*p_tree)[c.nid].RightChild();
      auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess();

      auto build_nidx = left_nidx;
      auto subtract_nidx = right_nidx;
      if (fewer_right) {
        std::swap(build_nidx, subtract_nidx);
      }
      nodes_to_build.push_back(CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}});
      nodes_to_sub.push_back(CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}});
    }

    size_t i = 0;
    auto space = this->ConstructHistSpace(nodes_to_build);
    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(
             {GenericParameter::kCpuId, param_.max_bin, hess})) {
      histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
                                   nodes_to_build, nodes_to_sub, gpair);
      i++;
    }
    monitor_->Stop(__func__);
  }

 public:
  explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, GenericParameter const *ctx,
                               std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task,
                               common::Monitor *monitor)
      : param_{std::move(param)},
        col_sampler_{std::move(column_sampler)},
        evaluator_{param_, info, ctx->Threads(), col_sampler_, task},
        ctx_{ctx},
        monitor_{monitor} {}

  void UpdateTree(RegTree *p_tree, std::vector<GradientPair> const &gpair,
                  common::Span<float> hess, DMatrix *p_fmat) {
    p_last_tree_ = p_tree;
    this->InitData(p_fmat, hess);

    Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
    auto &tree = *p_tree;
    driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
    bst_node_t num_leaves = 1;
    auto expand_set = driver.Pop();

    while (!expand_set.empty()) {
      // Candidates that can be further split.
      std::vector<CPUExpandEntry> valid_candidates;
      // Candidates that can be applied.
      std::vector<CPUExpandEntry> applied;
      for (auto const &candidate : expand_set) {
        if (!candidate.IsValid(param_, num_leaves)) {
          continue;
        }
        evaluator_.ApplyTreeSplit(candidate, p_tree);
        applied.push_back(candidate);
        num_leaves++;
        int left_child_nidx = tree[candidate.nid].LeftChild();
        if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx), num_leaves)) {
          valid_candidates.emplace_back(candidate);
        }
      }

      monitor_->Start("UpdatePosition");
      size_t i = 0;
      for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(
               {GenericParameter::kCpuId, param_.max_bin, hess})) {
        partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree);
        i++;
      }
      monitor_->Stop("UpdatePosition");

      std::vector<CPUExpandEntry> best_splits;
      if (!valid_candidates.empty()) {
        this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair, hess);
        for (auto const &candidate : valid_candidates) {
          int left_child_nidx = tree[candidate.nid].LeftChild();
          int right_child_nidx = tree[candidate.nid].RightChild();
          CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx), {}};
          CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx), {}};
          best_splits.push_back(l_best);
          best_splits.push_back(r_best);
        }
        auto const &histograms = histogram_builder_.Histogram();
        auto ft = p_fmat->Info().feature_types.ConstHostSpan();
        monitor_->Start("EvaluateSplits");
        evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &best_splits);
        monitor_->Stop("EvaluateSplits");
      }
      driver.Push(best_splits.begin(), best_splits.end());
      expand_set = driver.Pop();
    }
  }
};

/**
 * \brief Implementation for the approx tree method.  It constructs quantiles for every
 *        iteration.
 */
class GlobalApproxUpdater : public TreeUpdater {
  TrainParam param_;
  common::Monitor monitor_;
  CPUHistMakerTrainParam hist_param_;
  // Specializations for different histogram precisions.
  std::unique_ptr<GloablApproxBuilder<float>> f32_impl_;
  std::unique_ptr<GloablApproxBuilder<double>> f64_impl_;
  // Pointer to the last DMatrix, used for updating the prediction cache.
  DMatrix *cached_{nullptr};
  std::shared_ptr<common::ColumnSampler> column_sampler_ =
      std::make_shared<common::ColumnSampler>();
  ObjInfo task_;

 public:
  explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); }

  void Configure(const Args &args) override {
    param_.UpdateAllowUnknown(args);
    hist_param_.UpdateAllowUnknown(args);
  }
  void LoadConfig(Json const &in) override {
    auto const &config = get<Object const>(in);
    FromJson(config.at("train_param"), &this->param_);
    FromJson(config.at("hist_param"), &this->hist_param_);
  }
  void SaveConfig(Json *p_out) const override {
    auto &out = *p_out;
    out["train_param"] = ToJson(param_);
    out["hist_param"] = ToJson(hist_param_);
  }

  void InitData(TrainParam const &param, HostDeviceVector<GradientPair> *gpair,
                std::vector<GradientPair> *sampled) {
    auto const &h_gpair = gpair->HostVector();
    sampled->resize(h_gpair.size());
    std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin());
    auto &rnd = common::GlobalRandom();
    if (param.subsample != 1.0) {
      CHECK(param.sampling_method != TrainParam::kGradientBased)
          << "Gradient based sampling is not supported for approx tree method.";
      std::bernoulli_distribution coin_flip(param.subsample);
      std::transform(sampled->begin(), sampled->end(), sampled->begin(), [&](GradientPair &g) {
        if (coin_flip(rnd)) {
          return g;
        } else {
          return GradientPair{};
        }
      });
    }
  }

  char const *Name() const override { return "grow_histmaker"; }
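  // One boosting iteration: scale the learning rate by the number of trees in this batch,
  // pick the single or double precision builder based on hist_param_, apply row subsampling
  // to the gradients, extract hessian values for the weighted sketch, then grow each tree.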
  void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *m,
              const std::vector<RegTree *> &trees) override {
    float lr = param_.learning_rate;
    param_.learning_rate = lr / trees.size();

    if (hist_param_.single_precision_histogram) {
      f32_impl_ = std::make_unique<GloablApproxBuilder<float>>(param_, m->Info(), tparam_,
                                                               column_sampler_, task_, &monitor_);
    } else {
      f64_impl_ = std::make_unique<GloablApproxBuilder<double>>(param_, m->Info(), tparam_,
                                                                column_sampler_, task_, &monitor_);
    }

    std::vector<GradientPair> h_gpair;
    InitData(param_, gpair, &h_gpair);
    // Obtain the hessian values for weighted sketching.
    std::vector<float> hess(h_gpair.size());
    std::transform(h_gpair.begin(), h_gpair.end(), hess.begin(),
                   [](auto g) { return g.GetHess(); });

    cached_ = m;

    for (auto p_tree : trees) {
      if (hist_param_.single_precision_histogram) {
        this->f32_impl_->UpdateTree(p_tree, h_gpair, hess, m);
      } else {
        this->f64_impl_->UpdateTree(p_tree, h_gpair, hess, m);
      }
    }
    param_.learning_rate = lr;
  }

  bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
    if (data != cached_ || (!this->f32_impl_ && !this->f64_impl_)) {
      return false;
    }

    if (hist_param_.single_precision_histogram) {
      this->f32_impl_->UpdatePredictionCache(data, out_preds);
    } else {
      this->f64_impl_->UpdatePredictionCache(data, out_preds);
    }
    return true;
  }
};

DMLC_REGISTRY_FILE_TAG(grow_histmaker);

XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
    .describe(
        "Tree constructor that uses approximate histogram construction "
        "for each node.")
    .set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); });
}  // namespace tree
}  // namespace xgboost