Rewrite approx (#7214)

This PR rewrites the approx tree method to reuse the code base of the hist tree method, for better performance and code sharing.

The rewrite has many benefits:
- Support for both `max_leaves` and `max_depth`.
- Support for `grow_policy`.
- Support for monotonic constraints.
- Support for feature weights.
- Support for easier bin configuration (`max_bin`).
- Support for categorical data.
- Faster performance for most datasets (many times faster).
- Support for prediction cache.
- Significantly better performance for external memory.
- Unites the code base between approx and hist.
This commit is contained in:
Jiaming Yuan
2022-01-10 21:15:05 +08:00
committed by GitHub
parent ed95e77752
commit 001503186c
22 changed files with 635 additions and 264 deletions

View File

@@ -38,7 +38,7 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 };
int grow_policy;
uint32_t max_cat_to_onehot{1};
uint32_t max_cat_to_onehot{4};
//----- the rest parameters are less important ----
// minimum amount of hessian(weight) allowed in a child

View File

@@ -973,6 +973,7 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const {
}
size_t size = categories.size() - begin;
categories_sizes.emplace_back(static_cast<Integer::Int>(size));
CHECK_NE(size, 0);
}
}

View File

@@ -35,6 +35,7 @@ DMLC_REGISTRY_LINK_TAG(updater_refresh);
DMLC_REGISTRY_LINK_TAG(updater_prune);
DMLC_REGISTRY_LINK_TAG(updater_quantile_hist);
DMLC_REGISTRY_LINK_TAG(updater_histmaker);
DMLC_REGISTRY_LINK_TAG(updater_approx);
DMLC_REGISTRY_LINK_TAG(updater_sync);
#ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);

369
src/tree/updater_approx.cc Normal file
View File

@@ -0,0 +1,369 @@
/*!
* Copyright 2021 XGBoost contributors
*
* \brief Implementation for the approx tree method.
*/
#include "updater_approx.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "../common/random.h"
#include "../data/gradient_index.h"
#include "constraints.h"
#include "driver.h"
#include "hist/evaluate_splits.h"
#include "hist/histogram.h"
#include "hist/param.h"
#include "param.h"
#include "xgboost/base.h"
#include "xgboost/json.h"
#include "xgboost/tree_updater.h"
namespace xgboost {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_approx);
template <typename GradientSumT>
class GloablApproxBuilder {
protected:
TrainParam param_;
std::shared_ptr<common::ColumnSampler> col_sampler_;
HistEvaluator<GradientSumT, CPUExpandEntry> evaluator_;
HistogramBuilder<GradientSumT, CPUExpandEntry> histogram_builder_;
GenericParameter const *ctx_;
std::vector<ApproxRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache.
RegTree *p_last_tree_{nullptr};
common::Monitor *monitor_;
size_t n_batches_{0};
// Cache for histogram cuts.
common::HistogramCuts feature_values_;
public:
/**
 * \brief Prepare the per-tree state: one row partitioner per gradient index page, the
 *        cached histogram cuts, and a freshly reset histogram builder.
 *
 * \param p_fmat Training data, iterated as GHistIndexMatrix pages.
 * \param hess   Hessian values forwarded to the batch parameter.
 */
void InitData(DMatrix *p_fmat, common::Span<float> hess) {
  monitor_->Start(__func__);

  partitioner_.clear();
  n_batches_ = 0;
  int32_t n_total_bins = 0;
  // Building the GHistIndexMatrix pages is slow; a faster path would be welcome here.
  for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(
           {GenericParameter::kCpuId, param_.max_bin, hess, true})) {
    if (n_total_bins != 0) {
      // Every later page must report the same number of bins as the first one.
      CHECK_EQ(n_total_bins, page.cut.TotalBins());
    } else {
      // No bin count recorded yet: remember this page's count and cache its cuts.
      n_total_bins = page.cut.TotalBins();
      feature_values_ = page.cut;
    }
    partitioner_.emplace_back(page.Size(), page.base_rowid);
    ++n_batches_;
  }

  histogram_builder_.Reset(n_total_bins,
                           BatchParam{GenericParameter::kCpuId, param_.max_bin, hess},
                           ctx_->Threads(), n_batches_, rabit::IsDistributed());

  monitor_->Stop(__func__);
}
/**
 * \brief Set up the root node: accumulate the gradient sum, build the root histogram
 *        over every page, write the root leaf value and evaluate its best split.
 *
 * \param p_fmat Training data, iterated as GHistIndexMatrix pages.
 * \param gpair  Gradient pairs, one per row.
 * \param hess   Hessian values forwarded to the batch parameter.
 * \param p_tree Tree being grown; its root statistics and leaf value are written here.
 *
 * \return The root expand entry after split evaluation.
 */
CPUExpandEntry InitRoot(DMatrix *p_fmat, std::vector<GradientPair> const &gpair,
                        common::Span<float> hess, RegTree *p_tree) {
  monitor_->Start(__func__);

  CPUExpandEntry root_entry;
  root_entry.nid = RegTree::kRoot;
  root_entry.depth = 0;

  // Local gradient sum first, then summed across workers.
  GradStats root_sum;
  for (GradientPair const &pair : gpair) {
    root_sum.Add(pair);
  }
  rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&root_sum), 2);

  std::vector<CPUExpandEntry> nodes{root_entry};
  auto space = this->ConstructHistSpace(nodes);
  size_t page_idx = 0;
  for (auto const &page :
       p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
    histogram_builder_.BuildHist(page_idx, space, page, p_tree,
                                 partitioner_.at(page_idx).Partitions(), nodes, {}, gpair);
    ++page_idx;
  }

  auto weight = evaluator_.InitRoot(root_sum);
  auto &root_stat = p_tree->Stat(RegTree::kRoot);
  root_stat.sum_hess = root_sum.GetHess();
  root_stat.base_weight = weight;
  (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);

  auto ft = p_fmat->Info().feature_types.ConstHostSpan();
  evaluator_.EvaluateSplits(histogram_builder_.Histogram(), feature_values_, ft, *p_tree,
                            &nodes);

  monitor_->Stop(__func__);
  return nodes.front();
}
/**
 * \brief Add the leaf values of the last updated tree to the cached predictions.
 *
 * \param data      Matrix whose row count must match the size of out_preds.
 * \param out_preds Prediction cache, incremented in place.
 */
void UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) {
  monitor_->Start(__func__);
  // Caching prediction seems redundant for approx tree method, as sketching takes up
  // majority of training time.
  CHECK_EQ(out_preds.Size(), data->Info().num_row_);
  CHECK(p_last_tree_);

  size_t n_nodes = p_last_tree_->GetNodes().size();
  auto evaluator = evaluator_.Evaluator();
  auto const &tree = *p_last_tree_;
  auto const &snode = evaluator_.Stats();

  for (auto &part : partitioner_) {
    CHECK_EQ(part.Size(), n_nodes);
    // One task per (node, block of up to 1024 rows).
    common::BlockedSpace2d space(
        part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
    common::ParallelFor2d(space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) {
      // Only leaves contribute to the prediction.
      if (!tree[nidx].IsLeaf()) {
        return;
      }
      const auto rowset = part[nidx];
      auto const &stats = snode.at(nidx);
      auto leaf_value =
          evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * param_.learning_rate;
      const size_t *first = rowset.begin + r.begin();
      const size_t *last = rowset.begin + r.end();
      for (const size_t *it = first; it < last; ++it) {
        out_preds(*it) += leaf_value;
      }
    });
  }

  monitor_->Stop(__func__);
}
/**
 * \brief Construct a work space for building histograms.
 *
 * For each node in nodes_to_build the space is sized by the largest row count that node
 * has in any of the per-batch partitioners, split into blocks of 256 rows.
 *
 * Eventually we should move this function into histogram builder once hist tree method
 * supports external memory.
 *
 * \param nodes_to_build Nodes for which histograms will be built.
 *
 * \return The blocked 2-dim space over (node, rows).
 */
common::BlockedSpace2d ConstructHistSpace(
    std::vector<CPUExpandEntry> const &nodes_to_build) const {
  std::vector<size_t> partition_size(nodes_to_build.size(), 0);
  for (auto const &partition : partitioner_) {
    size_t k = 0;
    // Iterate by const reference: only `nid` is read, so copying each entry is wasted work.
    for (auto const &node : nodes_to_build) {
      auto n_rows_in_node = partition.Partitions()[node.nid].Size();
      partition_size[k] = std::max(partition_size[k], n_rows_in_node);
      k++;
    }
  }
  // NOTE(review): the size functor captures the local `partition_size` by reference and
  // the space is returned by value; this relies on BlockedSpace2d evaluating the functor
  // during construction only -- confirm against its definition.
  common::BlockedSpace2d space{nodes_to_build.size(),
                               [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; },
                               256};
  return space;
}
/**
 * \brief Build histograms for the children of the nodes that were just split.
 *
 * For every candidate, the child with the smaller hessian sum goes into the "build" set
 * and its sibling into the "subtract" set; both sets are handed to the histogram builder
 * for each page.
 *
 * \param p_fmat           Training data, iterated as GHistIndexMatrix pages.
 * \param p_tree           Tree being grown.
 * \param valid_candidates Applied splits whose children may be split further.
 * \param gpair            Gradient pairs, one per row.
 * \param hess             Hessian values forwarded to the batch parameter.
 */
void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree,
                    std::vector<CPUExpandEntry> const &valid_candidates,
                    std::vector<GradientPair> const &gpair, common::Span<float> hess) {
  monitor_->Start(__func__);

  std::vector<CPUExpandEntry> nodes_to_build;
  std::vector<CPUExpandEntry> nodes_to_sub;
  for (auto const &candidate : valid_candidates) {
    // Default to building the left child; switch when the right child is the lighter one.
    auto build_nidx = (*p_tree)[candidate.nid].LeftChild();
    auto subtract_nidx = (*p_tree)[candidate.nid].RightChild();
    if (candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess()) {
      std::swap(build_nidx, subtract_nidx);
    }
    nodes_to_build.push_back(CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}});
    nodes_to_sub.push_back(CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}});
  }

  auto space = this->ConstructHistSpace(nodes_to_build);
  size_t page_idx = 0;
  for (auto const &page :
       p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
    histogram_builder_.BuildHist(page_idx, space, page, p_tree,
                                 partitioner_.at(page_idx).Partitions(), nodes_to_build,
                                 nodes_to_sub, gpair);
    ++page_idx;
  }

  monitor_->Stop(__func__);
}
public:
/**
 * \brief Construct a builder for one call to the approx updater.
 *
 * \param param          Training parameters; taken by value and moved into the builder.
 * \param info           Meta info of the training data, forwarded to the split evaluator.
 * \param ctx            Runtime context; its thread count is forwarded to the evaluator.
 * \param column_sampler Column sampler shared with the split evaluator.
 * \param task           Objective info, forwarded to the split evaluator.
 * \param monitor        Monitor used for timing; stored as a non-owning pointer.
 */
explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, GenericParameter const *ctx,
std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task,
common::Monitor *monitor)
: param_{std::move(param)},
col_sampler_{std::move(column_sampler)},
// The evaluator is built from the members param_ and col_sampler_ (not the moved-from
// arguments); both are declared before evaluator_, so they are initialised first.
evaluator_{param_, info, ctx->Threads(), col_sampler_, task},
ctx_{ctx},
monitor_{monitor} {}
/**
 * \brief Grow one tree.
 *
 * Initialises the per-tree data and the root, then repeatedly pops a set of candidates
 * from the driver, applies the valid ones, repartitions the rows, builds histograms for
 * the new children and pushes their evaluated splits back to the driver until the driver
 * returns an empty set.
 *
 * \param p_tree Tree to grow; also remembered for the prediction cache update.
 * \param gpair  Gradient pairs, one per row.
 * \param hess   Hessian values forwarded to the batch parameter.
 * \param p_fmat Training data.
 */
void UpdateTree(RegTree *p_tree, std::vector<GradientPair> const &gpair, common::Span<float> hess,
                DMatrix *p_fmat) {
  p_last_tree_ = p_tree;
  this->InitData(p_fmat, hess);

  Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
  auto &tree = *p_tree;
  driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
  bst_node_t num_leaves = 1;

  for (auto expand_set = driver.Pop(); !expand_set.empty(); expand_set = driver.Pop()) {
    // Candidates whose children can be split further.
    std::vector<CPUExpandEntry> valid_candidates;
    // Candidates that are applied to the tree in this round.
    std::vector<CPUExpandEntry> applied;
    for (auto const &candidate : expand_set) {
      if (candidate.IsValid(param_, num_leaves)) {
        evaluator_.ApplyTreeSplit(candidate, p_tree);
        applied.push_back(candidate);
        ++num_leaves;
        int left_child_nidx = tree[candidate.nid].LeftChild();
        if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx),
                                         num_leaves)) {
          valid_candidates.emplace_back(candidate);
        }
      }
    }

    // Move the rows of every page into the newly created children.
    monitor_->Start("UpdatePosition");
    size_t page_idx = 0;
    for (auto const &page :
         p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
      partitioner_.at(page_idx).UpdatePosition(ctx_, page, applied, p_tree);
      ++page_idx;
    }
    monitor_->Stop("UpdatePosition");

    std::vector<CPUExpandEntry> best_splits;
    if (!valid_candidates.empty()) {
      this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair, hess);
      // Queue both children of every valid candidate, left child first.
      for (auto const &candidate : valid_candidates) {
        int left_child_nidx = tree[candidate.nid].LeftChild();
        int right_child_nidx = tree[candidate.nid].RightChild();
        best_splits.push_back(
            CPUExpandEntry{left_child_nidx, tree.GetDepth(left_child_nidx), {}});
        best_splits.push_back(
            CPUExpandEntry{right_child_nidx, tree.GetDepth(right_child_nidx), {}});
      }
      auto const &histograms = histogram_builder_.Histogram();
      auto ft = p_fmat->Info().feature_types.ConstHostSpan();
      monitor_->Start("EvaluateSplits");
      evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &best_splits);
      monitor_->Stop("EvaluateSplits");
    }
    driver.Push(best_splits.begin(), best_splits.end());
  }
}
};
/**
 * \brief Tree updater implementing the approx tree method.
 *
 * Every call to Update() creates a new builder, and the builder re-initialises its
 * gradient index pages (with the current hessian values) for every tree it grows.
 */
class GlobalApproxUpdater : public TreeUpdater {
  TrainParam param_;
  common::Monitor monitor_;
  CPUHistMakerTrainParam hist_param_;
  // Builder specialisations for the two histogram precisions. At most one of them is
  // alive after Update().
  std::unique_ptr<GloablApproxBuilder<float>> f32_impl_;
  std::unique_ptr<GloablApproxBuilder<double>> f64_impl_;
  // Pointer to the last DMatrix passed to Update(), used for updating the prediction cache.
  DMatrix *cached_{nullptr};
  std::shared_ptr<common::ColumnSampler> column_sampler_ =
      std::make_shared<common::ColumnSampler>();
  ObjInfo task_;

 public:
  explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); }

  /** \brief Update the training and histogram parameters from key/value arguments. */
  void Configure(const Args &args) override {
    param_.UpdateAllowUnknown(args);
    hist_param_.UpdateAllowUnknown(args);
  }

  /** \brief Load parameters from the "train_param" and "hist_param" JSON entries. */
  void LoadConfig(Json const &in) override {
    auto const &config = get<Object const>(in);
    FromJson(config.at("train_param"), &this->param_);
    FromJson(config.at("hist_param"), &this->hist_param_);
  }

  /** \brief Save parameters under the "train_param" and "hist_param" JSON entries. */
  void SaveConfig(Json *p_out) const override {
    auto &out = *p_out;
    out["train_param"] = ToJson(param_);
    out["hist_param"] = ToJson(hist_param_);
  }

  /**
   * \brief Copy the gradients to host and apply row subsampling.
   *
   * When `param.subsample` is not 1, each gradient pair is kept with probability
   * `param.subsample` and replaced by a default-constructed pair otherwise.
   *
   * \param param   Training parameters (subsample, sampling_method).
   * \param gpair   Input gradient pairs.
   * \param sampled Output; resized to the number of gradient pairs.
   */
  void InitData(TrainParam const &param, HostDeviceVector<GradientPair> *gpair,
                std::vector<GradientPair> *sampled) {
    auto const &h_gpair = gpair->HostVector();
    sampled->resize(h_gpair.size());
    std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin());
    auto &rnd = common::GlobalRandom();

    if (param.subsample != 1.0) {
      CHECK(param.sampling_method != TrainParam::kGradientBased)
          << "Gradient based sampling is not supported for approx tree method.";
      std::bernoulli_distribution coin_flip(param.subsample);
      // The functor only reads its argument, so take it by const reference.
      std::transform(sampled->begin(), sampled->end(), sampled->begin(),
                     [&](GradientPair const &g) {
                       if (coin_flip(rnd)) {
                         return g;
                       } else {
                         return GradientPair{};
                       }
                     });
    }
  }

  char const *Name() const override { return "grow_histmaker"; }

  /**
   * \brief Grow all trees of one boosting round.
   *
   * The learning rate handed to the builder is the configured rate divided by the number
   * of trees.
   *
   * \param gpair Gradient pairs for the current round.
   * \param m     Training data; remembered for UpdatePredictionCache().
   * \param trees Trees to grow.
   */
  void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *m,
              const std::vector<RegTree *> &trees) override {
    // Scale the learning rate on a local copy instead of mutating and later restoring
    // param_: if anything below throws, the configured learning rate stays intact.
    float const lr = param_.learning_rate;
    TrainParam param = param_;
    param.learning_rate = lr / trees.size();

    // Keep exactly one builder alive so that a stale builder of the other precision can
    // never be picked up by UpdatePredictionCache().
    if (hist_param_.single_precision_histogram) {
      f64_impl_.reset();
      f32_impl_ = std::make_unique<GloablApproxBuilder<float>>(param, m->Info(), tparam_,
                                                               column_sampler_, task_, &monitor_);
    } else {
      f32_impl_.reset();
      f64_impl_ = std::make_unique<GloablApproxBuilder<double>>(param, m->Info(), tparam_,
                                                                column_sampler_, task_, &monitor_);
    }

    std::vector<GradientPair> h_gpair;
    InitData(param, gpair, &h_gpair);
    // Obtain the hessian values for weighted sketching
    std::vector<float> hess(h_gpair.size());
    std::transform(h_gpair.begin(), h_gpair.end(), hess.begin(),
                   [](auto g) { return g.GetHess(); });

    cached_ = m;

    for (auto p_tree : trees) {
      if (hist_param_.single_precision_histogram) {
        this->f32_impl_->UpdateTree(p_tree, h_gpair, hess, m);
      } else {
        this->f64_impl_->UpdateTree(p_tree, h_gpair, hess, m);
      }
    }
  }

  /**
   * \brief Update the prediction cache using the last tree grown by Update().
   *
   * \return false when `data` is not the matrix last passed to Update() or when no builder
   *         exists for the currently configured histogram precision; true otherwise.
   */
  bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
    if (data != cached_) {
      return false;
    }
    // Check the builder that is actually going to be dereferenced, not merely that one of
    // the two exists.
    if (hist_param_.single_precision_histogram) {
      if (!this->f32_impl_) {
        return false;
      }
      this->f32_impl_->UpdatePredictionCache(data, out_preds);
    } else {
      if (!this->f64_impl_) {
        return false;
      }
      this->f64_impl_->UpdatePredictionCache(data, out_preds);
    }
    return true;
  }
};
// NOTE(review): this is the second DMLC_REGISTRY_FILE_TAG in this file (the other one is
// `updater_approx` near the top) -- confirm that both tags are intended.
DMLC_REGISTRY_FILE_TAG(grow_histmaker);
// Register GlobalApproxUpdater under the updater name "grow_histmaker", which is also the
// value returned by GlobalApproxUpdater::Name().
XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
.describe(
"Tree constructor that uses approximate histogram construction "
"for each node.")
.set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); });
} // namespace tree
} // namespace xgboost

View File

@@ -641,126 +641,10 @@ class CQHistMaker: public HistMaker {
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs_;
};
// global proposal
// Histogram maker that proposes cuts once, while a single node is in the expansion queue,
// and then reuses the cached proposal for every node in later calls.
class GlobalProposalHistMaker: public CQHistMaker {
public:
char const* Name() const override {
return "grow_histmaker";
}
protected:
// Propose cuts for the current expansion queue. A fresh proposal is computed by the base
// class only when the cache is empty; otherwise the cached cuts are replicated once per
// queued node.
void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
DMatrix *p_fmat,
const std::vector<bst_feature_t> &fset,
const RegTree &tree) override {
// A queue of size one invalidates the cache so that the proposal is recomputed below.
if (this->qexpand_.size() == 1) {
cached_rptr_.clear();
cached_cut_.clear();
}
if (cached_rptr_.size() == 0) {
CHECK_EQ(this->qexpand_.size(), 1U);
CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree);
cached_rptr_ = this->wspace_.rptr;
cached_cut_ = this->wspace_.cut;
} else {
// Rebuild the workspace from the cache: for each queued node append the cached segment
// lengths to rptr and a full copy of the cached cut values to cut.
this->wspace_.cut.clear();
this->wspace_.rptr.clear();
this->wspace_.rptr.push_back(0);
for (size_t i = 0; i < this->qexpand_.size(); ++i) {
for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) {
this->wspace_.rptr.push_back(
this->wspace_.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]);
}
this->wspace_.cut.insert(this->wspace_.cut.end(), cached_cut_.begin(), cached_cut_.end());
}
// Sanity checks: (fset.size() + 1) segments per queued node, and the last offset equals
// the total number of cut values.
CHECK_EQ(this->wspace_.rptr.size(),
(fset.size() + 1) * this->qexpand_.size() + 1);
CHECK_EQ(this->wspace_.rptr.back(), this->wspace_.cut.size());
}
}
// code to create histogram
void CreateHist(const std::vector<GradientPair> &gpair,
DMatrix *p_fmat,
const std::vector<bst_feature_t> &fset,
const RegTree &tree) override {
const MetaInfo &info = p_fmat->Info();
// fill in reverse map
this->feat2workindex_.resize(tree.param.num_feature);
this->work_set_ = fset;
std::fill(this->feat2workindex_.begin(), this->feat2workindex_.end(), -1);
for (size_t i = 0; i < fset.size(); ++i) {
this->feat2workindex_[fset[i]] = static_cast<int>(i);
}
// start to work
this->wspace_.Configure(1);
// to gain speedup in recovery
{
// One histogram buffer per OpenMP thread.
this->thread_hist_.resize(omp_get_max_threads());
// TWOPASS: use the real set + split set in the column iteration.
this->SetDefaultPostion(p_fmat, tree);
// Merge the split features into the work set, then sort and remove duplicates.
this->work_set_.insert(this->work_set_.end(), this->fsplit_set_.begin(),
this->fsplit_set_.end());
XGBOOST_PARALLEL_SORT(this->work_set_.begin(), this->work_set_.end(),
std::less<>{});
this->work_set_.resize(
std::unique(this->work_set_.begin(), this->work_set_.end()) - this->work_set_.begin());
// start accumulating statistics
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
// TWOPASS: use the real set + split set in the column iteration.
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);
auto page = batch.GetView();
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(this->work_set_.size());
// Exceptions thrown inside the parallel region are captured here and rethrown after it.
dmlc::OMPException exc;
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nsize; ++i) {
exc.Run([&]() {
int fid = this->work_set_[i];
int offset = this->feat2workindex_[fid];
// offset is -1 for features that are not in fset (added from fsplit_set_ only); those
// get no histogram column.
if (offset >= 0) {
this->UpdateHistCol(gpair, page[fid], info, tree,
fset, offset,
&this->thread_hist_[omp_get_thread_num()]);
}
});
}
exc.Rethrow();
}
// update node statistics.
this->GetNodeStats(gpair, *p_fmat, tree,
&(this->thread_stats_), &(this->node_stats_));
// Store each queued node's statistics in the extra slot that follows its fset.size()
// feature segments.
for (const int nid : this->qexpand_) {
const int wid = this->node2workindex_[nid];
this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)]
.data[0] = this->node_stats_[nid];
}
}
// Reduce the histogram data through histred_.
this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
this->wspace_.hset[0].data.size());
}
// cached unit pointer
std::vector<unsigned> cached_rptr_;
// cached cut value.
std::vector<bst_float> cached_cut_;
};
// Register CQHistMaker under the updater name "grow_local_histmaker".
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
.describe("Tree constructor that uses approximate histogram construction.")
.set_body([](ObjInfo) {
return new CQHistMaker();
});
// The updater for approx tree method.
// Registers GlobalProposalHistMaker under the updater name "grow_histmaker".
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
.describe("Tree constructor that uses approximate global of histogram construction.")
.set_body([](ObjInfo) {
return new GlobalProposalHistMaker();
});
} // namespace tree
} // namespace xgboost