Rewrite approx (#7214)

This PR rewrites the approx tree method to use codebase from hist for better performance and code sharing.

The rewrite has many benefits:
- Support for both `max_leaves` and `max_depth`.
- Support for `grow_policy`.
- Support for mono constraint.
- Support for feature weights.
- Support for easier bin configuration (`max_bin`).
- Support for categorical data.
- Faster performance for most of the datasets. (many times faster)
- Support for prediction cache.
- Significantly better performance for external memory.
- Unites the code base between approx and hist.
This commit is contained in:
Jiaming Yuan 2022-01-10 21:15:05 +08:00 committed by GitHub
parent ed95e77752
commit 001503186c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 635 additions and 264 deletions

View File

@ -57,6 +57,7 @@
#include "../src/tree/updater_refresh.cc"
#include "../src/tree/updater_sync.cc"
#include "../src/tree/updater_histmaker.cc"
#include "../src/tree/updater_approx.cc"
#include "../src/tree/constraints.cc"
// linear

View File

@ -3,7 +3,8 @@ Getting started with categorical data
=====================================
Experimental support for categorical data. After 1.5 XGBoost `gpu_hist` tree method has
experimental support for one-hot encoding based tree split.
experimental support for one-hot encoding based tree split, and in 1.6 `approx` supported
was added.
In before, users need to run an encoder themselves before passing the data into XGBoost,
which creates a sparse matrix and potentially increase memory usage. This demo showcases

View File

@ -154,7 +154,7 @@ Parameters for Tree Booster
* ``sketch_eps`` [default=0.03]
- Only used for ``tree_method=approx``.
- Only used for ``updater=grow_local_histmaker``.
- This roughly translates into ``O(1 / sketch_eps)`` number of bins.
Compared to directly select number of bins, this comes with theoretical guarantee with sketch accuracy.
- Usually user does not have to tune this.
@ -238,13 +238,27 @@ Parameters for Tree Booster
list is a group of indices of features that are allowed to interact with each other.
See :doc:`/tutorials/feature_interaction_constraint` for more information.
Additional parameters for ``hist`` and ``gpu_hist`` tree method
================================================================
Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
===========================================================================
* ``single_precision_histogram``, [default= ``false``]
- Use single precision to build histograms instead of double precision.
Additional parameters for ``approx`` tree method
================================================
* ``max_cat_to_onehot``
.. versionadded:: 1.6
.. note:: The support for this parameter is experimental.
- A threshold for deciding whether XGBoost should use one-hot encoding based split for
categorical data. When number of categories is lesser than the threshold then one-hot
encoding is chosen, otherwise the categories will be partitioned into children nodes.
Only relevant for regression and binary classification with `approx` tree method.
Additional parameters for Dart Booster (``booster=dart``)
=========================================================

View File

@ -53,7 +53,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:squarederror",
"max_bin" -> 16,
"max_bin" -> 64,
"tree_method" -> treeMethod)
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)

View File

@ -267,6 +267,16 @@ __model_doc = f'''
callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)]
max_cat_to_onehot : bool
.. versionadded:: 1.6.0
A threshold for deciding whether XGBoost should use one-hot encoding based split
for categorical data. When number of categories is lesser than the threshold then
one-hot encoding is chosen, otherwise the categories will be partitioned into
children nodes. Only relevant for regression and binary classification and
`approx` tree method.
kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of parameters
can be found :doc:`here </parameter>`.
@ -483,6 +493,7 @@ class XGBModel(XGBModelBase):
eval_metric: Optional[Union[str, List[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
callbacks: Optional[List[TrainingCallback]] = None,
max_cat_to_onehot: Optional[int] = None,
**kwargs: Any
) -> None:
if not SKLEARN_INSTALLED:
@ -522,6 +533,7 @@ class XGBModel(XGBModelBase):
self.eval_metric = eval_metric
self.early_stopping_rounds = early_stopping_rounds
self.callbacks = callbacks
self.max_cat_to_onehot = max_cat_to_onehot
if kwargs:
self.kwargs = kwargs
@ -800,8 +812,8 @@ class XGBModel(XGBModelBase):
_duplicated("callbacks")
callbacks = self.callbacks if self.callbacks is not None else callbacks
# lastly check categorical data support.
if self.enable_categorical and params.get("tree_method", None) != "gpu_hist":
tree_method = params.get("tree_method", None)
if self.enable_categorical and tree_method not in ("gpu_hist", "approx"):
raise ValueError(
"Experimental support for categorical data is not implemented for"
" current tree method yet."
@ -876,8 +888,7 @@ class XGBModel(XGBModelBase):
feature_weights :
Weight for each feature, defines the probability of each feature being
selected when colsample is being used. All values must be greater than 0,
otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
`exact` tree methods.
otherwise a `ValueError` is thrown.
callbacks :
.. deprecated: 1.6.0
@ -1750,8 +1761,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
feature_weights :
Weight for each feature, defines the probability of each feature being
selected when colsample is being used. All values must be greater than 0,
otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
`exact` tree methods.
otherwise a `ValueError` is thrown.
callbacks :
.. deprecated: 1.6.0

View File

@ -130,8 +130,7 @@ class BlockedSpace2d {
template <typename Func>
void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
const size_t num_blocks_in_space = space.Size();
nthreads = std::min(nthreads, omp_get_max_threads());
nthreads = std::max(nthreads, 1);
CHECK_GE(nthreads, 1);
dmlc::OMPException exc;
#pragma omp parallel num_threads(nthreads)
@ -277,9 +276,10 @@ inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) {
inline int32_t OmpGetNumThreads(int32_t n_threads) {
if (n_threads <= 0) {
n_threads = omp_get_num_procs();
n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());
}
n_threads = std::min(n_threads, OmpGetThreadLimit());
n_threads = std::max(n_threads, 1);
return n_threads;
}
} // namespace common

View File

@ -168,7 +168,7 @@ void GBTree::ConfigureUpdaters() {
// calling this function.
break;
case TreeMethod::kApprox:
tparam_.updater_seq = "grow_histmaker,prune";
tparam_.updater_seq = "grow_histmaker";
break;
case TreeMethod::kExact:
tparam_.updater_seq = "grow_colmaker,prune";

View File

@ -38,7 +38,7 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 };
int grow_policy;
uint32_t max_cat_to_onehot{1};
uint32_t max_cat_to_onehot{4};
//----- the rest parameters are less important ----
// minimum amount of hessian(weight) allowed in a child

View File

@ -973,6 +973,7 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const {
}
size_t size = categories.size() - begin;
categories_sizes.emplace_back(static_cast<Integer::Int>(size));
CHECK_NE(size, 0);
}
}

View File

@ -35,6 +35,7 @@ DMLC_REGISTRY_LINK_TAG(updater_refresh);
DMLC_REGISTRY_LINK_TAG(updater_prune);
DMLC_REGISTRY_LINK_TAG(updater_quantile_hist);
DMLC_REGISTRY_LINK_TAG(updater_histmaker);
DMLC_REGISTRY_LINK_TAG(updater_approx);
DMLC_REGISTRY_LINK_TAG(updater_sync);
#ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);

369
src/tree/updater_approx.cc Normal file
View File

@ -0,0 +1,369 @@
/*!
* Copyright 2021 XGBoost contributors
*
* \brief Implementation for the approx tree method.
*/
#include "updater_approx.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "../common/random.h"
#include "../data/gradient_index.h"
#include "constraints.h"
#include "driver.h"
#include "hist/evaluate_splits.h"
#include "hist/histogram.h"
#include "hist/param.h"
#include "param.h"
#include "xgboost/base.h"
#include "xgboost/json.h"
#include "xgboost/tree_updater.h"
namespace xgboost {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_approx);
template <typename GradientSumT>
class GloablApproxBuilder {
protected:
TrainParam param_;
std::shared_ptr<common::ColumnSampler> col_sampler_;
HistEvaluator<GradientSumT, CPUExpandEntry> evaluator_;
HistogramBuilder<GradientSumT, CPUExpandEntry> histogram_builder_;
GenericParameter const *ctx_;
std::vector<ApproxRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache.
RegTree *p_last_tree_{nullptr};
common::Monitor *monitor_;
size_t n_batches_{0};
// Cache for histogram cuts.
common::HistogramCuts feature_values_;
public:
void InitData(DMatrix *p_fmat, common::Span<float> hess) {
monitor_->Start(__func__);
n_batches_ = 0;
int32_t n_total_bins = 0;
partitioner_.clear();
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, param_.max_bin, hess, true})) {
if (n_total_bins == 0) {
n_total_bins = page.cut.TotalBins();
feature_values_ = page.cut;
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(page.Size(), page.base_rowid);
n_batches_++;
}
histogram_builder_.Reset(n_total_bins,
BatchParam{GenericParameter::kCpuId, param_.max_bin, hess},
ctx_->Threads(), n_batches_, rabit::IsDistributed());
monitor_->Stop(__func__);
}
CPUExpandEntry InitRoot(DMatrix *p_fmat, std::vector<GradientPair> const &gpair,
common::Span<float> hess, RegTree *p_tree) {
monitor_->Start(__func__);
CPUExpandEntry best;
best.nid = RegTree::kRoot;
best.depth = 0;
GradStats root_sum;
for (auto const &g : gpair) {
root_sum.Add(g);
}
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&root_sum), 2);
std::vector<CPUExpandEntry> nodes{best};
size_t i = 0;
auto space = this->ConstructHistSpace(nodes);
for (auto const &page :
p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
{}, gpair);
i++;
}
auto weight = evaluator_.InitRoot(root_sum);
p_tree->Stat(RegTree::kRoot).sum_hess = root_sum.GetHess();
p_tree->Stat(RegTree::kRoot).base_weight = weight;
(*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);
auto const &histograms = histogram_builder_.Histogram();
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &nodes);
monitor_->Stop(__func__);
return nodes.front();
}
void UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) {
monitor_->Start(__func__);
// Caching prediction seems redundant for approx tree method, as sketching takes up
// majority of training time.
CHECK_EQ(out_preds.Size(), data->Info().num_row_);
CHECK(p_last_tree_);
size_t n_nodes = p_last_tree_->GetNodes().size();
auto evaluator = evaluator_.Evaluator();
auto const &tree = *p_last_tree_;
auto const &snode = evaluator_.Stats();
for (auto &part : partitioner_) {
CHECK_EQ(part.Size(), n_nodes);
common::BlockedSpace2d space(
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
common::ParallelFor2d(space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) {
if (tree[nidx].IsLeaf()) {
const auto rowset = part[nidx];
auto const &stats = snode.at(nidx);
auto leaf_value =
evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * param_.learning_rate;
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
out_preds(*it) += leaf_value;
}
}
});
}
monitor_->Stop(__func__);
}
// Construct a work space for building histogram. Eventually we should move this
// function into histogram builder once hist tree method supports external memory.
common::BlockedSpace2d ConstructHistSpace(
std::vector<CPUExpandEntry> const &nodes_to_build) const {
std::vector<size_t> partition_size(nodes_to_build.size(), 0);
for (auto const &partition : partitioner_) {
size_t k = 0;
for (auto node : nodes_to_build) {
auto n_rows_in_node = partition.Partitions()[node.nid].Size();
partition_size[k] = std::max(partition_size[k], n_rows_in_node);
k++;
}
}
common::BlockedSpace2d space{nodes_to_build.size(),
[&](size_t nidx_in_set) { return partition_size[nidx_in_set]; },
256};
return space;
}
void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree,
std::vector<CPUExpandEntry> const &valid_candidates,
std::vector<GradientPair> const &gpair, common::Span<float> hess) {
monitor_->Start(__func__);
std::vector<CPUExpandEntry> nodes_to_build;
std::vector<CPUExpandEntry> nodes_to_sub;
for (auto const &c : valid_candidates) {
auto left_nidx = (*p_tree)[c.nid].LeftChild();
auto right_nidx = (*p_tree)[c.nid].RightChild();
auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess();
auto build_nidx = left_nidx;
auto subtract_nidx = right_nidx;
if (fewer_right) {
std::swap(build_nidx, subtract_nidx);
}
nodes_to_build.push_back(CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}});
nodes_to_sub.push_back(CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}});
}
size_t i = 0;
auto space = this->ConstructHistSpace(nodes_to_build);
for (auto const &page :
p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
nodes_to_build, nodes_to_sub, gpair);
i++;
}
monitor_->Stop(__func__);
}
public:
explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, GenericParameter const *ctx,
std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task,
common::Monitor *monitor)
: param_{std::move(param)},
col_sampler_{std::move(column_sampler)},
evaluator_{param_, info, ctx->Threads(), col_sampler_, task},
ctx_{ctx},
monitor_{monitor} {}
void UpdateTree(RegTree *p_tree, std::vector<GradientPair> const &gpair, common::Span<float> hess,
DMatrix *p_fmat) {
p_last_tree_ = p_tree;
this->InitData(p_fmat, hess);
Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
auto &tree = *p_tree;
driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
bst_node_t num_leaves = 1;
auto expand_set = driver.Pop();
while (!expand_set.empty()) {
// candidates that can be further splited.
std::vector<CPUExpandEntry> valid_candidates;
// candidates that can be applied.
std::vector<CPUExpandEntry> applied;
for (auto const &candidate : expand_set) {
if (!candidate.IsValid(param_, num_leaves)) {
continue;
}
evaluator_.ApplyTreeSplit(candidate, p_tree);
applied.push_back(candidate);
num_leaves++;
int left_child_nidx = tree[candidate.nid].LeftChild();
if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx), num_leaves)) {
valid_candidates.emplace_back(candidate);
}
}
monitor_->Start("UpdatePosition");
size_t i = 0;
for (auto const &page :
p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree);
i++;
}
monitor_->Stop("UpdatePosition");
std::vector<CPUExpandEntry> best_splits;
if (!valid_candidates.empty()) {
this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair, hess);
for (auto const &candidate : valid_candidates) {
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx), {}};
CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx), {}};
best_splits.push_back(l_best);
best_splits.push_back(r_best);
}
auto const &histograms = histogram_builder_.Histogram();
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
monitor_->Start("EvaluateSplits");
evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &best_splits);
monitor_->Stop("EvaluateSplits");
}
driver.Push(best_splits.begin(), best_splits.end());
expand_set = driver.Pop();
}
}
};
/**
* \brief Implementation for the approx tree method. It constructs quantile for every
* iteration.
*/
class GlobalApproxUpdater : public TreeUpdater {
TrainParam param_;
common::Monitor monitor_;
CPUHistMakerTrainParam hist_param_;
// specializations for different histogram precision.
std::unique_ptr<GloablApproxBuilder<float>> f32_impl_;
std::unique_ptr<GloablApproxBuilder<double>> f64_impl_;
// pointer to the last DMatrix, used for update prediction cache.
DMatrix *cached_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_ =
std::make_shared<common::ColumnSampler>();
ObjInfo task_;
public:
explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); }
void Configure(const Args &args) override {
param_.UpdateAllowUnknown(args);
hist_param_.UpdateAllowUnknown(args);
}
void LoadConfig(Json const &in) override {
auto const &config = get<Object const>(in);
FromJson(config.at("train_param"), &this->param_);
FromJson(config.at("hist_param"), &this->hist_param_);
}
void SaveConfig(Json *p_out) const override {
auto &out = *p_out;
out["train_param"] = ToJson(param_);
out["hist_param"] = ToJson(hist_param_);
}
void InitData(TrainParam const &param, HostDeviceVector<GradientPair> *gpair,
std::vector<GradientPair> *sampled) {
auto const &h_gpair = gpair->HostVector();
sampled->resize(h_gpair.size());
std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin());
auto &rnd = common::GlobalRandom();
if (param.subsample != 1.0) {
CHECK(param.sampling_method != TrainParam::kGradientBased)
<< "Gradient based sampling is not supported for approx tree method.";
std::bernoulli_distribution coin_flip(param.subsample);
std::transform(sampled->begin(), sampled->end(), sampled->begin(), [&](GradientPair &g) {
if (coin_flip(rnd)) {
return g;
} else {
return GradientPair{};
}
});
}
}
char const *Name() const override { return "grow_histmaker"; }
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *m,
const std::vector<RegTree *> &trees) override {
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
if (hist_param_.single_precision_histogram) {
f32_impl_ = std::make_unique<GloablApproxBuilder<float>>(param_, m->Info(), tparam_,
column_sampler_, task_, &monitor_);
} else {
f64_impl_ = std::make_unique<GloablApproxBuilder<double>>(param_, m->Info(), tparam_,
column_sampler_, task_, &monitor_);
}
std::vector<GradientPair> h_gpair;
InitData(param_, gpair, &h_gpair);
// Obtain the hessian values for weighted sketching
std::vector<float> hess(h_gpair.size());
std::transform(h_gpair.begin(), h_gpair.end(), hess.begin(),
[](auto g) { return g.GetHess(); });
cached_ = m;
for (auto p_tree : trees) {
if (hist_param_.single_precision_histogram) {
this->f32_impl_->UpdateTree(p_tree, h_gpair, hess, m);
} else {
this->f64_impl_->UpdateTree(p_tree, h_gpair, hess, m);
}
}
param_.learning_rate = lr;
}
bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
if (data != cached_ || (!this->f32_impl_ && !this->f64_impl_)) {
return false;
}
if (hist_param_.single_precision_histogram) {
this->f32_impl_->UpdatePredictionCache(data, out_preds);
} else {
this->f64_impl_->UpdatePredictionCache(data, out_preds);
}
return true;
}
};
DMLC_REGISTRY_FILE_TAG(grow_histmaker);
XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
.describe(
"Tree constructor that uses approximate histogram construction "
"for each node.")
.set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); });
} // namespace tree
} // namespace xgboost

View File

@ -641,126 +641,10 @@ class CQHistMaker: public HistMaker {
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs_;
};
// global proposal
class GlobalProposalHistMaker: public CQHistMaker {
public:
char const* Name() const override {
return "grow_histmaker";
}
protected:
void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
DMatrix *p_fmat,
const std::vector<bst_feature_t> &fset,
const RegTree &tree) override {
if (this->qexpand_.size() == 1) {
cached_rptr_.clear();
cached_cut_.clear();
}
if (cached_rptr_.size() == 0) {
CHECK_EQ(this->qexpand_.size(), 1U);
CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree);
cached_rptr_ = this->wspace_.rptr;
cached_cut_ = this->wspace_.cut;
} else {
this->wspace_.cut.clear();
this->wspace_.rptr.clear();
this->wspace_.rptr.push_back(0);
for (size_t i = 0; i < this->qexpand_.size(); ++i) {
for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) {
this->wspace_.rptr.push_back(
this->wspace_.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]);
}
this->wspace_.cut.insert(this->wspace_.cut.end(), cached_cut_.begin(), cached_cut_.end());
}
CHECK_EQ(this->wspace_.rptr.size(),
(fset.size() + 1) * this->qexpand_.size() + 1);
CHECK_EQ(this->wspace_.rptr.back(), this->wspace_.cut.size());
}
}
// code to create histogram
void CreateHist(const std::vector<GradientPair> &gpair,
DMatrix *p_fmat,
const std::vector<bst_feature_t> &fset,
const RegTree &tree) override {
const MetaInfo &info = p_fmat->Info();
// fill in reverse map
this->feat2workindex_.resize(tree.param.num_feature);
this->work_set_ = fset;
std::fill(this->feat2workindex_.begin(), this->feat2workindex_.end(), -1);
for (size_t i = 0; i < fset.size(); ++i) {
this->feat2workindex_[fset[i]] = static_cast<int>(i);
}
// start to work
this->wspace_.Configure(1);
// to gain speedup in recovery
{
this->thread_hist_.resize(omp_get_max_threads());
// TWOPASS: use the real set + split set in the column iteration.
this->SetDefaultPostion(p_fmat, tree);
this->work_set_.insert(this->work_set_.end(), this->fsplit_set_.begin(),
this->fsplit_set_.end());
XGBOOST_PARALLEL_SORT(this->work_set_.begin(), this->work_set_.end(),
std::less<>{});
this->work_set_.resize(
std::unique(this->work_set_.begin(), this->work_set_.end()) - this->work_set_.begin());
// start accumulating statistics
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
// TWOPASS: use the real set + split set in the column iteration.
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);
auto page = batch.GetView();
// start enumeration
const auto nsize = static_cast<bst_omp_uint>(this->work_set_.size());
dmlc::OMPException exc;
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nsize; ++i) {
exc.Run([&]() {
int fid = this->work_set_[i];
int offset = this->feat2workindex_[fid];
if (offset >= 0) {
this->UpdateHistCol(gpair, page[fid], info, tree,
fset, offset,
&this->thread_hist_[omp_get_thread_num()]);
}
});
}
exc.Rethrow();
}
// update node statistics.
this->GetNodeStats(gpair, *p_fmat, tree,
&(this->thread_stats_), &(this->node_stats_));
for (const int nid : this->qexpand_) {
const int wid = this->node2workindex_[nid];
this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)]
.data[0] = this->node_stats_[nid];
}
}
this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
this->wspace_.hset[0].data.size());
}
// cached unit pointer
std::vector<unsigned> cached_rptr_;
// cached cut value.
std::vector<bst_float> cached_cut_;
};
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
.describe("Tree constructor that uses approximate histogram construction.")
.set_body([](ObjInfo) {
return new CQHistMaker();
});
// The updater for approx tree method.
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
.describe("Tree constructor that uses approximate global of histogram construction.")
.set_body([](ObjInfo) {
return new GlobalProposalHistMaker();
});
} // namespace tree
} // namespace xgboost

View File

@ -35,7 +35,7 @@ TEST(GBTree, SelectTreeMethod) {
gbtree.Configure(args);
auto const& tparam = gbtree.GetTrainParam();
gbtree.Configure({{"tree_method", "approx"}});
ASSERT_EQ(tparam.updater_seq, "grow_histmaker,prune");
ASSERT_EQ(tparam.updater_seq, "grow_histmaker");
gbtree.Configure({{"tree_method", "exact"}});
ASSERT_EQ(tparam.updater_seq, "grow_colmaker,prune");
gbtree.Configure({{"tree_method", "hist"}});

View File

@ -72,5 +72,58 @@ TEST(Approx, Partitioner) {
}
}
}
TEST(Approx, PredictionCache) {
size_t n_samples = 2048, n_features = 13;
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
{
omp_set_num_threads(1);
GenericParameter ctx;
ctx.InitAllowUnknown(Args{{"nthread", "8"}});
std::unique_ptr<TreeUpdater> approx{
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
RegTree tree;
std::vector<RegTree *> trees{&tree};
auto gpair = GenerateRandomGradients(n_samples);
approx->Configure(Args{{"max_bin", "64"}});
approx->Update(&gpair, Xy.get(), trees);
HostDeviceVector<float> out_prediction_cached;
out_prediction_cached.Resize(n_samples);
auto cache = linalg::VectorView<float>{
out_prediction_cached.HostSpan(), {out_prediction_cached.Size()}, GenericParameter::kCpuId};
ASSERT_TRUE(approx->UpdatePredictionCache(Xy.get(), cache));
}
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
learner->SetParam("tree_method", "approx");
learner->SetParam("nthread", "0");
learner->Configure();
for (size_t i = 0; i < 8; ++i) {
learner->UpdateOneIter(i, Xy);
}
HostDeviceVector<float> out_prediction_cached;
learner->Predict(Xy, false, &out_prediction_cached, 0, 0);
Json model{Object()};
learner->SaveModel(&model);
HostDeviceVector<float> out_prediction;
{
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
learner->LoadModel(model);
learner->Predict(Xy, false, &out_prediction, 0, 0);
}
auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
auto const h_predt = out_prediction.ConstHostSpan();
ASSERT_EQ(h_predt.size(), h_predt_cached.size());
for (size_t i = 0; i < h_predt.size(); ++i) {
ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
}
}
} // namespace tree
} // namespace xgboost

View File

@ -315,57 +315,6 @@ TEST(GpuHist, TestHistogramIndex) {
TestHistogramIndexImpl();
}
// gamma is an alias of min_split_loss
int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector<GradientPair>* gpair) {
Args args {
{"max_depth", "1"},
{"max_leaves", "0"},
// Disable all other parameters.
{"colsample_bynode", "1"},
{"colsample_bylevel", "1"},
{"colsample_bytree", "1"},
{"min_child_weight", "0.01"},
{"reg_alpha", "0"},
{"reg_lambda", "0"},
{"max_delta_step", "0"},
// test gamma
{"gamma", std::to_string(gamma)}
};
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
GenericParameter generic_param(CreateEmptyGenericParam(0));
hist_maker.Configure(args, &generic_param);
RegTree tree;
hist_maker.Update(gpair, dmat, {&tree});
auto n_nodes = tree.NumExtraNodes();
return n_nodes;
}
TEST(GpuHist, MinSplitLoss) {
constexpr size_t kRows = 32;
constexpr size_t kCols = 16;
constexpr float kSparsity = 0.6;
auto dmat = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix();
auto gpair = GenerateRandomGradients(kRows);
{
int32_t n_nodes = TestMinSplitLoss(dmat.get(), 0.01, &gpair);
// This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured
// when writing this test, and only used for testing larger gamma (below) does prevent
// building tree.
ASSERT_EQ(n_nodes, 2);
}
{
int32_t n_nodes = TestMinSplitLoss(dmat.get(), 100.0, &gpair);
// No new nodes with gamma == 100.
ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
}
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
size_t gpu_page_size, RegTree* tree,
HostDeviceVector<bst_float>* preds, float subsample = 1.0f,

View File

@ -61,7 +61,7 @@ class TestGrowPolicy : public ::testing::Test {
}
};
TEST_F(TestGrowPolicy, DISABLED_Approx) {
TEST_F(TestGrowPolicy, Approx) {
this->TestTreeGrowPolicy("approx", "depthwise");
this->TestTreeGrowPolicy("approx", "lossguide");
}

View File

@ -114,4 +114,70 @@ TEST_F(UpdaterEtaTest, Approx) { this->RunTest("grow_histmaker"); }
#if defined(XGBOOST_USE_CUDA)
TEST_F(UpdaterEtaTest, GpuHist) { this->RunTest("grow_gpu_hist"); }
#endif // defined(XGBOOST_USE_CUDA)
class TestMinSplitLoss : public ::testing::Test {
std::shared_ptr<DMatrix> dmat_;
HostDeviceVector<GradientPair> gpair_;
void SetUp() override {
constexpr size_t kRows = 32;
constexpr size_t kCols = 16;
constexpr float kSparsity = 0.6;
dmat_ = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix();
gpair_ = GenerateRandomGradients(kRows);
}
int32_t Update(std::string updater, float gamma) {
Args args{{"max_depth", "1"},
{"max_leaves", "0"},
// Disable all other parameters.
{"colsample_bynode", "1"},
{"colsample_bylevel", "1"},
{"colsample_bytree", "1"},
{"min_child_weight", "0.01"},
{"reg_alpha", "0"},
{"reg_lambda", "0"},
{"max_delta_step", "0"},
// test gamma
{"gamma", std::to_string(gamma)}};
GenericParameter generic_param(CreateEmptyGenericParam(0));
auto up = std::unique_ptr<TreeUpdater>{
TreeUpdater::Create(updater, &generic_param, ObjInfo{ObjInfo::kRegression})};
up->Configure(args);
RegTree tree;
up->Update(&gpair_, dmat_.get(), {&tree});
auto n_nodes = tree.NumExtraNodes();
return n_nodes;
}
public:
void RunTest(std::string updater) {
{
int32_t n_nodes = Update(updater, 0.01);
// This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured
// when writing this test, and only used for testing larger gamma (below) does prevent
// building tree.
ASSERT_EQ(n_nodes, 2);
}
{
int32_t n_nodes = Update(updater, 100.0);
// No new nodes with gamma == 100.
ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
}
}
};
/* Exact tree method requires a pruner as an additional updater, so not tested here. */
TEST_F(TestMinSplitLoss, Approx) { this->RunTest("grow_histmaker"); }
TEST_F(TestMinSplitLoss, Hist) { this->RunTest("grow_quantile_histmaker"); }
#if defined(XGBOOST_USE_CUDA)
TEST_F(TestMinSplitLoss, GpuHist) { this->RunTest("grow_gpu_hist"); }
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost

View File

@ -7,6 +7,8 @@ from hypothesis import given, strategies, assume, settings, note
sys.path.append("tests/python")
import testing as tm
import test_updaters as test_up
parameter_strategy = strategies.fixed_dictionaries({
'max_depth': strategies.integers(0, 11),
@ -32,6 +34,8 @@ def train_result(param, dmat, num_rounds):
class TestGPUUpdaters:
cputest = test_up.TestTreeMethod()
@given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy)
@settings(deadline=None)
def test_gpu_hist(self, param, num_rounds, dataset):
@ -41,51 +45,12 @@ class TestGPUUpdaters:
note(result)
assert tm.non_increasing(result["train"][dataset.metric])
def run_categorical_basic(self, rows, cols, rounds, cats):
onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False)
by_etl_results = {}
by_builtin_results = {}
parameters = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
m = xgb.DMatrix(onehot, label, enable_categorical=False)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_etl_results,
)
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_builtin_results,
)
# There are guidelines on how to specify tolerance based on considering output as
# random variables. But in here the tree construction is extremely sensitive to
# floating point errors. An 1e-5 error in a histogram bin can lead to an entirely
# different tree. So even though the test is quite lenient, hypothesis can still
# pick up falsifying examples from time to time.
np.testing.assert_allclose(
np.array(by_etl_results["Train"]["rmse"]),
np.array(by_builtin_results["Train"]["rmse"]),
rtol=1e-3,
)
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
@given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical(self, rows, cols, rounds, cats):
self.run_categorical_basic(rows, cols, rounds, cats)
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
def test_categorical_32_cat(self):
'''32 hits the bound of integer bitset, so special test'''
@ -93,7 +58,7 @@ class TestGPUUpdaters:
cols = 10
cats = 32
rounds = 4
self.run_categorical_basic(rows, cols, rounds, cats)
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
def test_invalid_categorical(self):
import cupy as cp

View File

@ -63,7 +63,6 @@ training_dset = xgb.DMatrix(x, label=y)
class TestMonotoneConstraints:
def test_monotone_constraints_for_exact_tree_method(self):
# first check monotonicity for the 'exact' tree method
@ -76,32 +75,23 @@ class TestMonotoneConstraints:
)
assert is_correctly_constrained(constrained_exact_method)
def test_monotone_constraints_for_depthwise_hist_tree_method(self):
# next check monotonicity for the 'hist' tree method
params_for_constrained_hist_method = {
'tree_method': 'hist', 'verbosity': 1,
'monotone_constraints': '(1, -1)'
}
constrained_hist_method = xgb.train(
params_for_constrained_hist_method, training_dset
@pytest.mark.parametrize(
"tree_method,policy",
[
("hist", "depthwise"),
("approx", "depthwise"),
("hist", "lossguide"),
("approx", "lossguide"),
],
)
assert is_correctly_constrained(constrained_hist_method)
def test_monotone_constraints_for_lossguide_hist_tree_method(self):
# next check monotonicity for the 'hist' tree method
params_for_constrained_hist_method = {
'tree_method': 'hist', 'verbosity': 1,
'grow_policy': 'lossguide',
'monotone_constraints': '(1, -1)'
def test_monotone_constraints(self, tree_method: str, policy: str) -> None:
params_for_constrained = {
"tree_method": tree_method,
"grow_policy": policy,
"monotone_constraints": "(1, -1)",
}
constrained_hist_method = xgb.train(
params_for_constrained_hist_method, training_dset
)
assert is_correctly_constrained(constrained_hist_method)
constrained = xgb.train(params_for_constrained, training_dset)
assert is_correctly_constrained(constrained)
@pytest.mark.parametrize('format', [dict, list])
def test_monotone_constraints_feature_names(self, format):

View File

@ -45,14 +45,20 @@ class TestTreeMethod:
result = train_result(param, dataset.get_dmat(), num_rounds)
assert tm.non_increasing(result['train'][dataset.metric])
@given(exact_parameter_strategy, strategies.integers(1, 20),
tm.dataset_strategy)
@given(
exact_parameter_strategy,
hist_parameter_strategy,
strategies.integers(1, 20),
tm.dataset_strategy,
)
@settings(deadline=None)
def test_approx(self, param, num_rounds, dataset):
param['tree_method'] = 'approx'
def test_approx(self, param, hist_param, num_rounds, dataset):
param["tree_method"] = "approx"
param = dataset.set_params(param)
param.update(hist_param)
result = train_result(param, dataset.get_dmat(), num_rounds)
assert tm.non_increasing(result['train'][dataset.metric], 1e-3)
note(result)
assert tm.non_increasing(result["train"][dataset.metric])
@pytest.mark.skipif(**tm.no_sklearn())
def test_pruner(self):
@ -126,3 +132,53 @@ class TestTreeMethod:
y = [1000000., 0., 0., 500000.]
w = [0, 0, 1, 0]
model.fit(X, y, sample_weight=w)
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False)
by_etl_results = {}
by_builtin_results = {}
predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
# Use one-hot exclusively
parameters = {
"tree_method": tree_method, "predictor": predictor, "max_cat_to_onehot": 9999
}
m = xgb.DMatrix(onehot, label, enable_categorical=False)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_etl_results,
)
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_builtin_results,
)
# There are guidelines on how to specify tolerance based on considering output as
# random variables. But in here the tree construction is extremely sensitive to
# floating point errors. An 1e-5 error in a histogram bin can lead to an entirely
# different tree. So even though the test is quite lenient, hypothesis can still
# pick up falsifying examples from time to time.
np.testing.assert_allclose(
np.array(by_etl_results["Train"]["rmse"]),
np.array(by_builtin_results["Train"]["rmse"]),
rtol=1e-3,
)
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
@given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical(self, rows, cols, rounds, cats):
self.run_categorical_basic(rows, cols, rounds, cats, "approx")

View File

@ -1184,9 +1184,13 @@ class TestWithDask:
for arg in rabit_args:
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
port_env = arg.decode('utf-8')
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
uri_env = arg.decode("utf-8")
port = port_env.split('=')
env = os.environ.copy()
env[port[0]] = port[1]
uri = uri_env.split("=")
env["DMLC_TRACKER_URI"] = uri[1]
return subprocess.run([str(exe), test], env=env, capture_output=True)
with LocalCluster(n_workers=4) as cluster:
@ -1210,11 +1214,13 @@ class TestWithDask:
@pytest.mark.gtest
def test_quantile_basic(self) -> None:
self.run_quantile('DistributedBasic')
self.run_quantile('SortedDistributedBasic')
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile(self) -> None:
self.run_quantile('Distributed')
self.run_quantile('SortedDistributed')
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
@ -1252,13 +1258,17 @@ class TestWithDask:
for i in range(kCols):
fw[i] *= float(i)
fw = da.from_array(fw)
poly_increasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor)
poly_increasing = run_feature_weights(
X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor
)
fw = np.ones(shape=(kCols,))
for i in range(kCols):
fw[i] *= float(kCols - i)
fw = da.from_array(fw)
poly_decreasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor)
poly_decreasing = run_feature_weights(
X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor
)
# Approxmated test, this is dependent on the implementation of random
# number generator in std library.

View File

@ -1031,10 +1031,10 @@ def test_pandas_input():
np.array([0, 1]))
def run_feature_weights(X, y, fw, model=xgb.XGBRegressor):
def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor):
with tempfile.TemporaryDirectory() as tmpdir:
colsample_bynode = 0.5
reg = model(tree_method='hist', colsample_bynode=colsample_bynode)
reg = model(tree_method=tree_method, colsample_bynode=colsample_bynode)
reg.fit(X, y, feature_weights=fw)
model_path = os.path.join(tmpdir, 'model.json')
@ -1069,7 +1069,8 @@ def run_feature_weights(X, y, fw, model=xgb.XGBRegressor):
return w
def test_feature_weights():
@pytest.mark.parametrize("tree_method", ["approx", "hist"])
def test_feature_weights(tree_method):
kRows = 512
kCols = 64
X = rng.randn(kRows, kCols)
@ -1078,12 +1079,12 @@ def test_feature_weights():
fw = np.ones(shape=(kCols,))
for i in range(kCols):
fw[i] *= float(i)
poly_increasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
poly_increasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor)
fw = np.ones(shape=(kCols,))
for i in range(kCols):
fw[i] *= float(kCols - i)
poly_decreasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
poly_decreasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor)
# Approxmated test, this is dependent on the implementation of random
# number generator in std library.