Rewrite approx (#7214)
This PR rewrites the approx tree method to use the codebase from hist for better performance and code sharing. The rewrite has many benefits:

- Support for both `max_leaves` and `max_depth`.
- Support for `grow_policy`.
- Support for monotone constraints.
- Support for feature weights.
- Easier bin configuration (`max_bin`).
- Support for categorical data.
- Faster performance for most datasets (many times faster).
- Support for the prediction cache.
- Significantly better performance for external memory.
- Unifies the code base between approx and hist.
This commit is contained in:
parent ed95e77752, commit 001503186c
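For orientation, here is a minimal sketch of what the rewritten method enables from the Python package. The data and every parameter value below are illustrative only, not taken from the PR:

```python
import numpy as np
import xgboost as xgb

# Illustrative synthetic data.
rng = np.random.default_rng(0)
X = rng.normal(size=(1024, 2))
y = 2.0 * X[:, 0] - X[:, 1] + rng.normal(scale=0.1, size=1024)

dtrain = xgb.DMatrix(X, label=y)
params = {
    "tree_method": "approx",
    "grow_policy": "lossguide",        # grow policy, new for approx
    "max_leaves": 31,                  # usable together with max_depth now
    "max_depth": 6,
    "max_bin": 64,                     # direct bin count, replacing sketch_eps
    "monotone_constraints": "(1,-1)",  # monotone constraints, new for approx
}
booster = xgb.train(params, dtrain, num_boost_round=10)
```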
@@ -57,6 +57,7 @@
 #include "../src/tree/updater_refresh.cc"
 #include "../src/tree/updater_sync.cc"
 #include "../src/tree/updater_histmaker.cc"
+#include "../src/tree/updater_approx.cc"
 #include "../src/tree/constraints.cc"

 // linear
@@ -3,7 +3,8 @@ Getting started with categorical data
 =====================================

 Experimental support for categorical data. After 1.5 XGBoost `gpu_hist` tree method has
-experimental support for one-hot encoding based tree split.
+experimental support for one-hot encoding based tree split, and in 1.6 `approx` support
+was added.

 In before, users need to run an encoder themselves before passing the data into XGBoost,
 which creates a sparse matrix and potentially increase memory usage. This demo showcases
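As a companion to the doc change, a hedged sketch of feeding categorical data straight to `approx` via the Python API (the frame contents are made up):

```python
import pandas as pd
import xgboost as xgb

# A native categorical column; no manual one-hot encoding step.
df = pd.DataFrame({
    "color": pd.Categorical(["red", "green", "blue", "green"] * 64),
    "size": [1.0, 2.0, 3.0, 4.0] * 64,
})
y = (df["size"] > 2.0).astype(int)

dtrain = xgb.DMatrix(df, label=y, enable_categorical=True)
booster = xgb.train(
    {"tree_method": "approx", "objective": "binary:logistic"},
    dtrain,
    num_boost_round=4,
)
```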
@@ -154,7 +154,7 @@ Parameters for Tree Booster

 * ``sketch_eps`` [default=0.03]

-  - Only used for ``tree_method=approx``.
+  - Only used for ``updater=grow_local_histmaker``.
   - This roughly translates into ``O(1 / sketch_eps)`` number of bins.
     Compared to directly select number of bins, this comes with theoretical guarantee with sketch accuracy.
   - Usually user does not have to tune this.
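Because the rewritten `approx` reads `max_bin` directly, the old `sketch_eps` rule of thumb translates by hand; a tiny illustration (this uses the approximation stated above, not an exact equivalence):

```python
# O(1 / sketch_eps) bins: the old default of 0.03 maps to roughly 33 bins.
sketch_eps = 0.03
equivalent_max_bin = round(1 / sketch_eps)  # ~33
params = {"tree_method": "approx", "max_bin": equivalent_max_bin}
```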
@@ -238,13 +238,27 @@ Parameters for Tree Booster
   list is a group of indices of features that are allowed to interact with each other.
   See :doc:`/tutorials/feature_interaction_constraint` for more information.

-Additional parameters for ``hist`` and ``gpu_hist`` tree method
-================================================================
+Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
+===========================================================================

 * ``single_precision_histogram``, [default= ``false``]

   - Use single precision to build histograms instead of double precision.

+Additional parameters for ``approx`` tree method
+================================================
+
+* ``max_cat_to_onehot``
+
+  .. versionadded:: 1.6
+
+  .. note:: The support for this parameter is experimental.
+
+  - A threshold for deciding whether XGBoost should use one-hot encoding based split for
+    categorical data. When number of categories is lesser than the threshold then one-hot
+    encoding is chosen, otherwise the categories will be partitioned into children nodes.
+    Only relevant for regression and binary classification with `approx` tree method.
+
 Additional parameters for Dart Booster (``booster=dart``)
 =========================================================
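A hedged sketch of the threshold in action, reusing the `dtrain` with categorical features from the earlier sketch; the value 5 is arbitrary:

```python
params = {
    "tree_method": "approx",
    "objective": "binary:logistic",
    # Features with fewer than 5 categories get one-hot style splits;
    # features with 5 or more are partitioned between the two children.
    "max_cat_to_onehot": 5,
}
booster = xgb.train(params, dtrain, num_boost_round=4)
```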
@@ -53,7 +53,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
       "max_depth" -> "6",
       "silent" -> "1",
       "objective" -> "reg:squarederror",
-      "max_bin" -> 16,
+      "max_bin" -> 64,
       "tree_method" -> treeMethod)

     val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
@@ -267,6 +267,16 @@ __model_doc = f'''
         callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                                 save_best=True)]

+    max_cat_to_onehot : bool
+
+        .. versionadded:: 1.6.0
+
+        A threshold for deciding whether XGBoost should use one-hot encoding based split
+        for categorical data. When number of categories is lesser than the threshold then
+        one-hot encoding is chosen, otherwise the categories will be partitioned into
+        children nodes. Only relevant for regression and binary classification and
+        `approx` tree method.
+
     kwargs : dict, optional
         Keyword arguments for XGBoost Booster object. Full documentation of parameters
         can be found :doc:`here </parameter>`.
@@ -483,6 +493,7 @@ class XGBModel(XGBModelBase):
         eval_metric: Optional[Union[str, List[str], Callable]] = None,
         early_stopping_rounds: Optional[int] = None,
         callbacks: Optional[List[TrainingCallback]] = None,
+        max_cat_to_onehot: Optional[int] = None,
         **kwargs: Any
     ) -> None:
         if not SKLEARN_INSTALLED:
@@ -522,6 +533,7 @@ class XGBModel(XGBModelBase):
         self.eval_metric = eval_metric
         self.early_stopping_rounds = early_stopping_rounds
         self.callbacks = callbacks
+        self.max_cat_to_onehot = max_cat_to_onehot
         if kwargs:
             self.kwargs = kwargs
@@ -800,8 +812,8 @@ class XGBModel(XGBModelBase):
             _duplicated("callbacks")
             callbacks = self.callbacks if self.callbacks is not None else callbacks

         # lastly check categorical data support.
-        if self.enable_categorical and params.get("tree_method", None) != "gpu_hist":
+        tree_method = params.get("tree_method", None)
+        if self.enable_categorical and tree_method not in ("gpu_hist", "approx"):
             raise ValueError(
                 "Experimental support for categorical data is not implemented for"
                 " current tree method yet."
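From the sklearn wrapper, the relaxed check means `approx` now passes the categorical gate; a hedged sketch (`df` and `y` as in the categorical sketch above):

```python
from xgboost import XGBClassifier

clf = XGBClassifier(
    tree_method="approx",   # previously only "gpu_hist" was accepted here
    enable_categorical=True,
    max_cat_to_onehot=5,
)
clf.fit(df, y)
```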
@@ -876,8 +888,7 @@ class XGBModel(XGBModelBase):
         feature_weights :
             Weight for each feature, defines the probability of each feature being
             selected when colsample is being used. All values must be greater than 0,
-            otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
-            `exact` tree methods.
+            otherwise a `ValueError` is thrown.

         callbacks :
             .. deprecated: 1.6.0
@@ -1750,8 +1761,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
         feature_weights :
             Weight for each feature, defines the probability of each feature being
             selected when colsample is being used. All values must be greater than 0,
-            otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
-            `exact` tree methods.
+            otherwise a `ValueError` is thrown.

         callbacks :
             .. deprecated: 1.6.0
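With the restriction lifted from the docstrings above, feature weights apply to `approx` as well; a minimal sketch with made-up data:

```python
import numpy as np
import xgboost as xgb

X = np.random.randn(256, 8)
y = X @ np.arange(8, dtype=float)
# Later columns get proportionally higher sampling probability.
fw = np.arange(1.0, 9.0)

reg = xgb.XGBRegressor(tree_method="approx", colsample_bynode=0.5)
reg.fit(X, y, feature_weights=fw)
```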
@@ -130,8 +130,7 @@ class BlockedSpace2d {
 template <typename Func>
 void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
   const size_t num_blocks_in_space = space.Size();
-  nthreads = std::min(nthreads, omp_get_max_threads());
-  nthreads = std::max(nthreads, 1);
+  CHECK_GE(nthreads, 1);

   dmlc::OMPException exc;
 #pragma omp parallel num_threads(nthreads)
@@ -277,9 +276,10 @@ inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) {

 inline int32_t OmpGetNumThreads(int32_t n_threads) {
   if (n_threads <= 0) {
-    n_threads = omp_get_num_procs();
+    n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());
   }
   n_threads = std::min(n_threads, OmpGetThreadLimit());
   n_threads = std::max(n_threads, 1);
   return n_threads;
 }
 }  // namespace common
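Restated as a small Python mirror for readability (a paraphrase of the C++ above, not part of the codebase):

```python
def omp_get_num_threads(n_threads, n_procs, max_threads, thread_limit):
    # A non-positive request means "use everything available", now additionally
    # bounded by omp_get_max_threads() so OMP_NUM_THREADS is respected.
    if n_threads <= 0:
        n_threads = min(n_procs, max_threads)
    return max(min(n_threads, thread_limit), 1)
```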
@@ -168,7 +168,7 @@ void GBTree::ConfigureUpdaters() {
       // calling this function.
       break;
     case TreeMethod::kApprox:
-      tparam_.updater_seq = "grow_histmaker,prune";
+      tparam_.updater_seq = "grow_histmaker";
       break;
     case TreeMethod::kExact:
       tparam_.updater_seq = "grow_colmaker,prune";
@@ -38,7 +38,7 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
   enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 };
   int grow_policy;

-  uint32_t max_cat_to_onehot{1};
+  uint32_t max_cat_to_onehot{4};

   //----- the rest parameters are less important ----
   // minimum amount of hessian(weight) allowed in a child
@@ -973,6 +973,7 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const {
     }
     size_t size = categories.size() - begin;
     categories_sizes.emplace_back(static_cast<Integer::Int>(size));
+    CHECK_NE(size, 0);
   }
 }
@@ -35,6 +35,7 @@ DMLC_REGISTRY_LINK_TAG(updater_refresh);
 DMLC_REGISTRY_LINK_TAG(updater_prune);
 DMLC_REGISTRY_LINK_TAG(updater_quantile_hist);
 DMLC_REGISTRY_LINK_TAG(updater_histmaker);
+DMLC_REGISTRY_LINK_TAG(updater_approx);
 DMLC_REGISTRY_LINK_TAG(updater_sync);
 #ifdef XGBOOST_USE_CUDA
 DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);
src/tree/updater_approx.cc (new file, 369 lines)
@@ -0,0 +1,369 @@
/*!
 * Copyright 2021 XGBoost contributors
 *
 * \brief Implementation for the approx tree method.
 */
#include "updater_approx.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "../common/random.h"
#include "../data/gradient_index.h"
#include "constraints.h"
#include "driver.h"
#include "hist/evaluate_splits.h"
#include "hist/histogram.h"
#include "hist/param.h"
#include "param.h"
#include "xgboost/base.h"
#include "xgboost/json.h"
#include "xgboost/tree_updater.h"

namespace xgboost {
namespace tree {

DMLC_REGISTRY_FILE_TAG(updater_approx);

template <typename GradientSumT>
class GloablApproxBuilder {
 protected:
  TrainParam param_;
  std::shared_ptr<common::ColumnSampler> col_sampler_;
  HistEvaluator<GradientSumT, CPUExpandEntry> evaluator_;
  HistogramBuilder<GradientSumT, CPUExpandEntry> histogram_builder_;
  GenericParameter const *ctx_;

  std::vector<ApproxRowPartitioner> partitioner_;
  // Pointer to last updated tree, used for update prediction cache.
  RegTree *p_last_tree_{nullptr};
  common::Monitor *monitor_;
  size_t n_batches_{0};
  // Cache for histogram cuts.
  common::HistogramCuts feature_values_;

 public:
  void InitData(DMatrix *p_fmat, common::Span<float> hess) {
    monitor_->Start(__func__);
    n_batches_ = 0;
    int32_t n_total_bins = 0;
    partitioner_.clear();
    // Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(
             {GenericParameter::kCpuId, param_.max_bin, hess, true})) {
      if (n_total_bins == 0) {
        n_total_bins = page.cut.TotalBins();
        feature_values_ = page.cut;
      } else {
        CHECK_EQ(n_total_bins, page.cut.TotalBins());
      }
      partitioner_.emplace_back(page.Size(), page.base_rowid);
      n_batches_++;
    }

    histogram_builder_.Reset(n_total_bins,
                             BatchParam{GenericParameter::kCpuId, param_.max_bin, hess},
                             ctx_->Threads(), n_batches_, rabit::IsDistributed());
    monitor_->Stop(__func__);
  }

  CPUExpandEntry InitRoot(DMatrix *p_fmat, std::vector<GradientPair> const &gpair,
                          common::Span<float> hess, RegTree *p_tree) {
    monitor_->Start(__func__);
    CPUExpandEntry best;
    best.nid = RegTree::kRoot;
    best.depth = 0;
    GradStats root_sum;
    for (auto const &g : gpair) {
      root_sum.Add(g);
    }
    rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&root_sum), 2);
    std::vector<CPUExpandEntry> nodes{best};
    size_t i = 0;
    auto space = this->ConstructHistSpace(nodes);
    for (auto const &page :
         p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
      histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
                                   {}, gpair);
      i++;
    }

    auto weight = evaluator_.InitRoot(root_sum);
    p_tree->Stat(RegTree::kRoot).sum_hess = root_sum.GetHess();
    p_tree->Stat(RegTree::kRoot).base_weight = weight;
    (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);

    auto const &histograms = histogram_builder_.Histogram();
    auto ft = p_fmat->Info().feature_types.ConstHostSpan();
    evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &nodes);
    monitor_->Stop(__func__);

    return nodes.front();
  }

  void UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) {
    monitor_->Start(__func__);
    // Caching prediction seems redundant for approx tree method, as sketching takes up
    // majority of training time.
    CHECK_EQ(out_preds.Size(), data->Info().num_row_);
    CHECK(p_last_tree_);

    size_t n_nodes = p_last_tree_->GetNodes().size();

    auto evaluator = evaluator_.Evaluator();
    auto const &tree = *p_last_tree_;
    auto const &snode = evaluator_.Stats();
    for (auto &part : partitioner_) {
      CHECK_EQ(part.Size(), n_nodes);
      common::BlockedSpace2d space(
          part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
      common::ParallelFor2d(space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) {
        if (tree[nidx].IsLeaf()) {
          const auto rowset = part[nidx];
          auto const &stats = snode.at(nidx);
          auto leaf_value =
              evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * param_.learning_rate;
          for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
            out_preds(*it) += leaf_value;
          }
        }
      });
    }
    monitor_->Stop(__func__);
  }

  // Construct a work space for building histogram. Eventually we should move this
  // function into histogram builder once hist tree method supports external memory.
  common::BlockedSpace2d ConstructHistSpace(
      std::vector<CPUExpandEntry> const &nodes_to_build) const {
    std::vector<size_t> partition_size(nodes_to_build.size(), 0);
    for (auto const &partition : partitioner_) {
      size_t k = 0;
      for (auto node : nodes_to_build) {
        auto n_rows_in_node = partition.Partitions()[node.nid].Size();
        partition_size[k] = std::max(partition_size[k], n_rows_in_node);
        k++;
      }
    }
    common::BlockedSpace2d space{nodes_to_build.size(),
                                 [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; },
                                 256};
    return space;
  }

  void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree,
                      std::vector<CPUExpandEntry> const &valid_candidates,
                      std::vector<GradientPair> const &gpair, common::Span<float> hess) {
    monitor_->Start(__func__);
    std::vector<CPUExpandEntry> nodes_to_build;
    std::vector<CPUExpandEntry> nodes_to_sub;

    for (auto const &c : valid_candidates) {
      auto left_nidx = (*p_tree)[c.nid].LeftChild();
      auto right_nidx = (*p_tree)[c.nid].RightChild();
      auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess();

      auto build_nidx = left_nidx;
      auto subtract_nidx = right_nidx;
      if (fewer_right) {
        std::swap(build_nidx, subtract_nidx);
      }
      nodes_to_build.push_back(CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}});
      nodes_to_sub.push_back(CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}});
    }

    size_t i = 0;
    auto space = this->ConstructHistSpace(nodes_to_build);
    for (auto const &page :
         p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
      histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
                                   nodes_to_build, nodes_to_sub, gpair);
      i++;
    }
    monitor_->Stop(__func__);
  }

 public:
  explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, GenericParameter const *ctx,
                               std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task,
                               common::Monitor *monitor)
      : param_{std::move(param)},
        col_sampler_{std::move(column_sampler)},
        evaluator_{param_, info, ctx->Threads(), col_sampler_, task},
        ctx_{ctx},
        monitor_{monitor} {}

  void UpdateTree(RegTree *p_tree, std::vector<GradientPair> const &gpair, common::Span<float> hess,
                  DMatrix *p_fmat) {
    p_last_tree_ = p_tree;
    this->InitData(p_fmat, hess);

    Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
    auto &tree = *p_tree;
    driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
    bst_node_t num_leaves = 1;
    auto expand_set = driver.Pop();

    while (!expand_set.empty()) {
      // candidates that can be further splited.
      std::vector<CPUExpandEntry> valid_candidates;
      // candidates that can be applied.
      std::vector<CPUExpandEntry> applied;
      for (auto const &candidate : expand_set) {
        if (!candidate.IsValid(param_, num_leaves)) {
          continue;
        }
        evaluator_.ApplyTreeSplit(candidate, p_tree);
        applied.push_back(candidate);
        num_leaves++;
        int left_child_nidx = tree[candidate.nid].LeftChild();
        if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx), num_leaves)) {
          valid_candidates.emplace_back(candidate);
        }
      }

      monitor_->Start("UpdatePosition");
      size_t i = 0;
      for (auto const &page :
           p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
        partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree);
        i++;
      }
      monitor_->Stop("UpdatePosition");

      std::vector<CPUExpandEntry> best_splits;
      if (!valid_candidates.empty()) {
        this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair, hess);
        for (auto const &candidate : valid_candidates) {
          int left_child_nidx = tree[candidate.nid].LeftChild();
          int right_child_nidx = tree[candidate.nid].RightChild();
          CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx), {}};
          CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx), {}};
          best_splits.push_back(l_best);
          best_splits.push_back(r_best);
        }
        auto const &histograms = histogram_builder_.Histogram();
        auto ft = p_fmat->Info().feature_types.ConstHostSpan();
        monitor_->Start("EvaluateSplits");
        evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &best_splits);
        monitor_->Stop("EvaluateSplits");
      }
      driver.Push(best_splits.begin(), best_splits.end());
      expand_set = driver.Pop();
    }
  }
};

/**
 * \brief Implementation for the approx tree method. It constructs quantile for every
 *        iteration.
 */
class GlobalApproxUpdater : public TreeUpdater {
  TrainParam param_;
  common::Monitor monitor_;
  CPUHistMakerTrainParam hist_param_;
  // specializations for different histogram precision.
  std::unique_ptr<GloablApproxBuilder<float>> f32_impl_;
  std::unique_ptr<GloablApproxBuilder<double>> f64_impl_;
  // pointer to the last DMatrix, used for update prediction cache.
  DMatrix *cached_{nullptr};
  std::shared_ptr<common::ColumnSampler> column_sampler_ =
      std::make_shared<common::ColumnSampler>();
  ObjInfo task_;

 public:
  explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); }

  void Configure(const Args &args) override {
    param_.UpdateAllowUnknown(args);
    hist_param_.UpdateAllowUnknown(args);
  }
  void LoadConfig(Json const &in) override {
    auto const &config = get<Object const>(in);
    FromJson(config.at("train_param"), &this->param_);
    FromJson(config.at("hist_param"), &this->hist_param_);
  }
  void SaveConfig(Json *p_out) const override {
    auto &out = *p_out;
    out["train_param"] = ToJson(param_);
    out["hist_param"] = ToJson(hist_param_);
  }

  void InitData(TrainParam const &param, HostDeviceVector<GradientPair> *gpair,
                std::vector<GradientPair> *sampled) {
    auto const &h_gpair = gpair->HostVector();
    sampled->resize(h_gpair.size());
    std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin());
    auto &rnd = common::GlobalRandom();
    if (param.subsample != 1.0) {
      CHECK(param.sampling_method != TrainParam::kGradientBased)
          << "Gradient based sampling is not supported for approx tree method.";
      std::bernoulli_distribution coin_flip(param.subsample);
      std::transform(sampled->begin(), sampled->end(), sampled->begin(), [&](GradientPair &g) {
        if (coin_flip(rnd)) {
          return g;
        } else {
          return GradientPair{};
        }
      });
    }
  }

  char const *Name() const override { return "grow_histmaker"; }

  void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *m,
              const std::vector<RegTree *> &trees) override {
    float lr = param_.learning_rate;
    param_.learning_rate = lr / trees.size();

    if (hist_param_.single_precision_histogram) {
      f32_impl_ = std::make_unique<GloablApproxBuilder<float>>(param_, m->Info(), tparam_,
                                                               column_sampler_, task_, &monitor_);
    } else {
      f64_impl_ = std::make_unique<GloablApproxBuilder<double>>(param_, m->Info(), tparam_,
                                                                column_sampler_, task_, &monitor_);
    }

    std::vector<GradientPair> h_gpair;
    InitData(param_, gpair, &h_gpair);
    // Obtain the hessian values for weighted sketching
    std::vector<float> hess(h_gpair.size());
    std::transform(h_gpair.begin(), h_gpair.end(), hess.begin(),
                   [](auto g) { return g.GetHess(); });

    cached_ = m;

    for (auto p_tree : trees) {
      if (hist_param_.single_precision_histogram) {
        this->f32_impl_->UpdateTree(p_tree, h_gpair, hess, m);
      } else {
        this->f64_impl_->UpdateTree(p_tree, h_gpair, hess, m);
      }
    }
    param_.learning_rate = lr;
  }

  bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
    if (data != cached_ || (!this->f32_impl_ && !this->f64_impl_)) {
      return false;
    }

    if (hist_param_.single_precision_histogram) {
      this->f32_impl_->UpdatePredictionCache(data, out_preds);
    } else {
      this->f64_impl_->UpdatePredictionCache(data, out_preds);
    }
    return true;
  }
};

DMLC_REGISTRY_FILE_TAG(grow_histmaker);

XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
    .describe(
        "Tree constructor that uses approximate histogram construction "
        "for each node.")
    .set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); });
}  // namespace tree
}  // namespace xgboost
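`BuildHistogram` above uses the standard subtraction trick shared with `hist`: build the histogram only for the child with the smaller hessian sum and obtain the sibling by subtracting it from the parent. A minimal NumPy sketch of the idea (not the actual XGBoost kernel):

```python
import numpy as np

def child_histograms(parent_hist, grad, binned, rows_small):
    """parent_hist: (n_bins,) gradient sums of the parent node;
    grad: (n_rows,) gradients; binned: (n_rows,) bin index per row;
    rows_small: indices of rows landing in the smaller child."""
    small = np.zeros_like(parent_hist)
    np.add.at(small, binned[rows_small], grad[rows_small])  # direct build
    large = parent_hist - small                              # subtraction trick
    return small, large
```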
@@ -641,126 +641,10 @@ class CQHistMaker: public HistMaker {
   std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs_;
 };

-// global proposal
-class GlobalProposalHistMaker: public CQHistMaker {
- public:
-  char const* Name() const override {
-    return "grow_histmaker";
-  }
-
- protected:
-  void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
-                          DMatrix *p_fmat,
-                          const std::vector<bst_feature_t> &fset,
-                          const RegTree &tree) override {
-    if (this->qexpand_.size() == 1) {
-      cached_rptr_.clear();
-      cached_cut_.clear();
-    }
-    if (cached_rptr_.size() == 0) {
-      CHECK_EQ(this->qexpand_.size(), 1U);
-      CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree);
-      cached_rptr_ = this->wspace_.rptr;
-      cached_cut_ = this->wspace_.cut;
-    } else {
-      this->wspace_.cut.clear();
-      this->wspace_.rptr.clear();
-      this->wspace_.rptr.push_back(0);
-      for (size_t i = 0; i < this->qexpand_.size(); ++i) {
-        for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) {
-          this->wspace_.rptr.push_back(
-              this->wspace_.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]);
-        }
-        this->wspace_.cut.insert(this->wspace_.cut.end(), cached_cut_.begin(), cached_cut_.end());
-      }
-      CHECK_EQ(this->wspace_.rptr.size(),
-               (fset.size() + 1) * this->qexpand_.size() + 1);
-      CHECK_EQ(this->wspace_.rptr.back(), this->wspace_.cut.size());
-    }
-  }
-
-  // code to create histogram
-  void CreateHist(const std::vector<GradientPair> &gpair,
-                  DMatrix *p_fmat,
-                  const std::vector<bst_feature_t> &fset,
-                  const RegTree &tree) override {
-    const MetaInfo &info = p_fmat->Info();
-    // fill in reverse map
-    this->feat2workindex_.resize(tree.param.num_feature);
-    this->work_set_ = fset;
-    std::fill(this->feat2workindex_.begin(), this->feat2workindex_.end(), -1);
-    for (size_t i = 0; i < fset.size(); ++i) {
-      this->feat2workindex_[fset[i]] = static_cast<int>(i);
-    }
-    // start to work
-    this->wspace_.Configure(1);
-    // to gain speedup in recovery
-    {
-      this->thread_hist_.resize(omp_get_max_threads());
-
-      // TWOPASS: use the real set + split set in the column iteration.
-      this->SetDefaultPostion(p_fmat, tree);
-      this->work_set_.insert(this->work_set_.end(), this->fsplit_set_.begin(),
-                             this->fsplit_set_.end());
-      XGBOOST_PARALLEL_SORT(this->work_set_.begin(), this->work_set_.end(),
-                            std::less<>{});
-      this->work_set_.resize(
-          std::unique(this->work_set_.begin(), this->work_set_.end()) - this->work_set_.begin());
-
-      // start accumulating statistics
-      for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
-        // TWOPASS: use the real set + split set in the column iteration.
-        this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);
-        auto page = batch.GetView();
-
-        // start enumeration
-        const auto nsize = static_cast<bst_omp_uint>(this->work_set_.size());
-        dmlc::OMPException exc;
-#pragma omp parallel for schedule(dynamic, 1)
-        for (bst_omp_uint i = 0; i < nsize; ++i) {
-          exc.Run([&]() {
-            int fid = this->work_set_[i];
-            int offset = this->feat2workindex_[fid];
-            if (offset >= 0) {
-              this->UpdateHistCol(gpair, page[fid], info, tree,
-                                  fset, offset,
-                                  &this->thread_hist_[omp_get_thread_num()]);
-            }
-          });
-        }
-        exc.Rethrow();
-      }
-
-      // update node statistics.
-      this->GetNodeStats(gpair, *p_fmat, tree,
-                         &(this->thread_stats_), &(this->node_stats_));
-      for (const int nid : this->qexpand_) {
-        const int wid = this->node2workindex_[nid];
-        this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)]
-            .data[0] = this->node_stats_[nid];
-      }
-    }
-    this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
-                             this->wspace_.hset[0].data.size());
-  }
-
-  // cached unit pointer
-  std::vector<unsigned> cached_rptr_;
-  // cached cut value.
-  std::vector<bst_float> cached_cut_;
-};
-
 XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
     .describe("Tree constructor that uses approximate histogram construction.")
     .set_body([](ObjInfo) {
       return new CQHistMaker();
     });

-// The updater for approx tree method.
-XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
-    .describe("Tree constructor that uses approximate global of histogram construction.")
-    .set_body([](ObjInfo) {
-      return new GlobalProposalHistMaker();
-    });
 }  // namespace tree
 }  // namespace xgboost
@@ -35,7 +35,7 @@ TEST(GBTree, SelectTreeMethod) {
   gbtree.Configure(args);
   auto const& tparam = gbtree.GetTrainParam();
   gbtree.Configure({{"tree_method", "approx"}});
-  ASSERT_EQ(tparam.updater_seq, "grow_histmaker,prune");
+  ASSERT_EQ(tparam.updater_seq, "grow_histmaker");
   gbtree.Configure({{"tree_method", "exact"}});
   ASSERT_EQ(tparam.updater_seq, "grow_colmaker,prune");
   gbtree.Configure({{"tree_method", "hist"}});
@@ -72,5 +72,58 @@ TEST(Approx, Partitioner) {
     }
   }
 }
+
+TEST(Approx, PredictionCache) {
+  size_t n_samples = 2048, n_features = 13;
+  auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
+
+  {
+    omp_set_num_threads(1);
+    GenericParameter ctx;
+    ctx.InitAllowUnknown(Args{{"nthread", "8"}});
+    std::unique_ptr<TreeUpdater> approx{
+        TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
+    RegTree tree;
+    std::vector<RegTree *> trees{&tree};
+    auto gpair = GenerateRandomGradients(n_samples);
+    approx->Configure(Args{{"max_bin", "64"}});
+    approx->Update(&gpair, Xy.get(), trees);
+    HostDeviceVector<float> out_prediction_cached;
+    out_prediction_cached.Resize(n_samples);
+    auto cache = linalg::VectorView<float>{
+        out_prediction_cached.HostSpan(), {out_prediction_cached.Size()}, GenericParameter::kCpuId};
+    ASSERT_TRUE(approx->UpdatePredictionCache(Xy.get(), cache));
+  }
+
+  std::unique_ptr<Learner> learner{Learner::Create({Xy})};
+  learner->SetParam("tree_method", "approx");
+  learner->SetParam("nthread", "0");
+  learner->Configure();
+
+  for (size_t i = 0; i < 8; ++i) {
+    learner->UpdateOneIter(i, Xy);
+  }
+
+  HostDeviceVector<float> out_prediction_cached;
+  learner->Predict(Xy, false, &out_prediction_cached, 0, 0);
+
+  Json model{Object()};
+  learner->SaveModel(&model);
+
+  HostDeviceVector<float> out_prediction;
+  {
+    std::unique_ptr<Learner> learner{Learner::Create({Xy})};
+    learner->LoadModel(model);
+    learner->Predict(Xy, false, &out_prediction, 0, 0);
+  }
+
+  auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
+  auto const h_predt = out_prediction.ConstHostSpan();
+
+  ASSERT_EQ(h_predt.size(), h_predt_cached.size());
+  for (size_t i = 0; i < h_predt.size(); ++i) {
+    ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
+  }
+}
 }  // namespace tree
 }  // namespace xgboost
@@ -315,57 +315,6 @@ TEST(GpuHist, TestHistogramIndex) {
   TestHistogramIndexImpl();
 }

-// gamma is an alias of min_split_loss
-int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector<GradientPair>* gpair) {
-  Args args {
-    {"max_depth", "1"},
-    {"max_leaves", "0"},
-
-    // Disable all other parameters.
-    {"colsample_bynode", "1"},
-    {"colsample_bylevel", "1"},
-    {"colsample_bytree", "1"},
-    {"min_child_weight", "0.01"},
-    {"reg_alpha", "0"},
-    {"reg_lambda", "0"},
-    {"max_delta_step", "0"},
-
-    // test gamma
-    {"gamma", std::to_string(gamma)}
-  };
-
-  tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
-  GenericParameter generic_param(CreateEmptyGenericParam(0));
-  hist_maker.Configure(args, &generic_param);
-
-  RegTree tree;
-  hist_maker.Update(gpair, dmat, {&tree});
-
-  auto n_nodes = tree.NumExtraNodes();
-  return n_nodes;
-}
-
-TEST(GpuHist, MinSplitLoss) {
-  constexpr size_t kRows = 32;
-  constexpr size_t kCols = 16;
-  constexpr float kSparsity = 0.6;
-  auto dmat = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix();
-  auto gpair = GenerateRandomGradients(kRows);
-
-  {
-    int32_t n_nodes = TestMinSplitLoss(dmat.get(), 0.01, &gpair);
-    // This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured
-    // when writing this test, and only used for testing larger gamma (below) does prevent
-    // building tree.
-    ASSERT_EQ(n_nodes, 2);
-  }
-  {
-    int32_t n_nodes = TestMinSplitLoss(dmat.get(), 100.0, &gpair);
-    // No new nodes with gamma == 100.
-    ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
-  }
-}
-
 void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
                 size_t gpu_page_size, RegTree* tree,
                 HostDeviceVector<bst_float>* preds, float subsample = 1.0f,
@@ -61,7 +61,7 @@ class TestGrowPolicy : public ::testing::Test {
   }
 };

-TEST_F(TestGrowPolicy, DISABLED_Approx) {
+TEST_F(TestGrowPolicy, Approx) {
   this->TestTreeGrowPolicy("approx", "depthwise");
   this->TestTreeGrowPolicy("approx", "lossguide");
 }
@@ -114,4 +114,70 @@ TEST_F(UpdaterEtaTest, Approx) { this->RunTest("grow_histmaker"); }
 #if defined(XGBOOST_USE_CUDA)
 TEST_F(UpdaterEtaTest, GpuHist) { this->RunTest("grow_gpu_hist"); }
 #endif  // defined(XGBOOST_USE_CUDA)
+
+class TestMinSplitLoss : public ::testing::Test {
+  std::shared_ptr<DMatrix> dmat_;
+  HostDeviceVector<GradientPair> gpair_;
+
+  void SetUp() override {
+    constexpr size_t kRows = 32;
+    constexpr size_t kCols = 16;
+    constexpr float kSparsity = 0.6;
+    dmat_ = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix();
+    gpair_ = GenerateRandomGradients(kRows);
+  }
+
+  int32_t Update(std::string updater, float gamma) {
+    Args args{{"max_depth", "1"},
+              {"max_leaves", "0"},
+
+              // Disable all other parameters.
+              {"colsample_bynode", "1"},
+              {"colsample_bylevel", "1"},
+              {"colsample_bytree", "1"},
+              {"min_child_weight", "0.01"},
+              {"reg_alpha", "0"},
+              {"reg_lambda", "0"},
+              {"max_delta_step", "0"},
+
+              // test gamma
+              {"gamma", std::to_string(gamma)}};
+
+    GenericParameter generic_param(CreateEmptyGenericParam(0));
+    auto up = std::unique_ptr<TreeUpdater>{
+        TreeUpdater::Create(updater, &generic_param, ObjInfo{ObjInfo::kRegression})};
+    up->Configure(args);
+
+    RegTree tree;
+    up->Update(&gpair_, dmat_.get(), {&tree});
+
+    auto n_nodes = tree.NumExtraNodes();
+    return n_nodes;
+  }
+
+ public:
+  void RunTest(std::string updater) {
+    {
+      int32_t n_nodes = Update(updater, 0.01);
+      // This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured
+      // when writing this test, and only used for testing larger gamma (below) does prevent
+      // building tree.
+      ASSERT_EQ(n_nodes, 2);
+    }
+    {
+      int32_t n_nodes = Update(updater, 100.0);
+      // No new nodes with gamma == 100.
+      ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
+    }
+  }
+};
+
+/* Exact tree method requires a pruner as an additional updater, so not tested here. */
+
+TEST_F(TestMinSplitLoss, Approx) { this->RunTest("grow_histmaker"); }
+
+TEST_F(TestMinSplitLoss, Hist) { this->RunTest("grow_quantile_histmaker"); }
+#if defined(XGBOOST_USE_CUDA)
+TEST_F(TestMinSplitLoss, GpuHist) { this->RunTest("grow_gpu_hist"); }
+#endif  // defined(XGBOOST_USE_CUDA)
 }  // namespace xgboost
@@ -7,6 +7,8 @@ from hypothesis import given, strategies, assume, settings, note

 sys.path.append("tests/python")
 import testing as tm
+import test_updaters as test_up


 parameter_strategy = strategies.fixed_dictionaries({
     'max_depth': strategies.integers(0, 11),
@@ -32,6 +34,8 @@ def train_result(param, dmat, num_rounds):


 class TestGPUUpdaters:
+    cputest = test_up.TestTreeMethod()
+
     @given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy)
     @settings(deadline=None)
     def test_gpu_hist(self, param, num_rounds, dataset):
@@ -41,51 +45,12 @@ class TestGPUUpdaters:
         note(result)
         assert tm.non_increasing(result["train"][dataset.metric])

-    def run_categorical_basic(self, rows, cols, rounds, cats):
-        onehot, label = tm.make_categorical(rows, cols, cats, True)
-        cat, _ = tm.make_categorical(rows, cols, cats, False)
-
-        by_etl_results = {}
-        by_builtin_results = {}
-
-        parameters = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
-
-        m = xgb.DMatrix(onehot, label, enable_categorical=False)
-        xgb.train(
-            parameters,
-            m,
-            num_boost_round=rounds,
-            evals=[(m, "Train")],
-            evals_result=by_etl_results,
-        )
-
-        m = xgb.DMatrix(cat, label, enable_categorical=True)
-        xgb.train(
-            parameters,
-            m,
-            num_boost_round=rounds,
-            evals=[(m, "Train")],
-            evals_result=by_builtin_results,
-        )
-
-        # There are guidelines on how to specify tolerance based on considering output as
-        # random variables. But in here the tree construction is extremely sensitive to
-        # floating point errors. An 1e-5 error in a histogram bin can lead to an entirely
-        # different tree. So even though the test is quite lenient, hypothesis can still
-        # pick up falsifying examples from time to time.
-        np.testing.assert_allclose(
-            np.array(by_etl_results["Train"]["rmse"]),
-            np.array(by_builtin_results["Train"]["rmse"]),
-            rtol=1e-3,
-        )
-        assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
-
     @given(strategies.integers(10, 400), strategies.integers(3, 8),
            strategies.integers(1, 2), strategies.integers(4, 7))
     @settings(deadline=None)
     @pytest.mark.skipif(**tm.no_pandas())
     def test_categorical(self, rows, cols, rounds, cats):
-        self.run_categorical_basic(rows, cols, rounds, cats)
+        self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")

     def test_categorical_32_cat(self):
         '''32 hits the bound of integer bitset, so special test'''
@@ -93,7 +58,7 @@ class TestGPUUpdaters:
         cols = 10
         cats = 32
         rounds = 4
-        self.run_categorical_basic(rows, cols, rounds, cats)
+        self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")

     def test_invalid_categorical(self):
         import cupy as cp
@@ -63,7 +63,6 @@ training_dset = xgb.DMatrix(x, label=y)


 class TestMonotoneConstraints:
-
     def test_monotone_constraints_for_exact_tree_method(self):

         # first check monotonicity for the 'exact' tree method
@@ -76,32 +75,23 @@ class TestMonotoneConstraints:
         )
         assert is_correctly_constrained(constrained_exact_method)

-    def test_monotone_constraints_for_depthwise_hist_tree_method(self):
-
-        # next check monotonicity for the 'hist' tree method
-        params_for_constrained_hist_method = {
-            'tree_method': 'hist', 'verbosity': 1,
-            'monotone_constraints': '(1, -1)'
-        }
-        constrained_hist_method = xgb.train(
-            params_for_constrained_hist_method, training_dset
-        )
-
-        assert is_correctly_constrained(constrained_hist_method)
-
-    def test_monotone_constraints_for_lossguide_hist_tree_method(self):
-
-        # next check monotonicity for the 'hist' tree method
-        params_for_constrained_hist_method = {
-            'tree_method': 'hist', 'verbosity': 1,
-            'grow_policy': 'lossguide',
-            'monotone_constraints': '(1, -1)'
-        }
-        constrained_hist_method = xgb.train(
-            params_for_constrained_hist_method, training_dset
-        )
-
-        assert is_correctly_constrained(constrained_hist_method)
+    @pytest.mark.parametrize(
+        "tree_method,policy",
+        [
+            ("hist", "depthwise"),
+            ("approx", "depthwise"),
+            ("hist", "lossguide"),
+            ("approx", "lossguide"),
+        ],
+    )
+    def test_monotone_constraints(self, tree_method: str, policy: str) -> None:
+        params_for_constrained = {
+            "tree_method": tree_method,
+            "grow_policy": policy,
+            "monotone_constraints": "(1, -1)",
+        }
+        constrained = xgb.train(params_for_constrained, training_dset)
+        assert is_correctly_constrained(constrained)

     @pytest.mark.parametrize('format', [dict, list])
     def test_monotone_constraints_feature_names(self, format):
@@ -45,14 +45,20 @@ class TestTreeMethod:
         result = train_result(param, dataset.get_dmat(), num_rounds)
         assert tm.non_increasing(result['train'][dataset.metric])

-    @given(exact_parameter_strategy, strategies.integers(1, 20),
-           tm.dataset_strategy)
+    @given(
+        exact_parameter_strategy,
+        hist_parameter_strategy,
+        strategies.integers(1, 20),
+        tm.dataset_strategy,
+    )
     @settings(deadline=None)
-    def test_approx(self, param, num_rounds, dataset):
-        param['tree_method'] = 'approx'
+    def test_approx(self, param, hist_param, num_rounds, dataset):
+        param["tree_method"] = "approx"
         param = dataset.set_params(param)
+        param.update(hist_param)
         result = train_result(param, dataset.get_dmat(), num_rounds)
-        assert tm.non_increasing(result['train'][dataset.metric], 1e-3)
+        note(result)
+        assert tm.non_increasing(result["train"][dataset.metric])

     @pytest.mark.skipif(**tm.no_sklearn())
     def test_pruner(self):
@@ -126,3 +132,53 @@ class TestTreeMethod:
         y = [1000000., 0., 0., 500000.]
         w = [0, 0, 1, 0]
         model.fit(X, y, sample_weight=w)
+
+    def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
+        onehot, label = tm.make_categorical(rows, cols, cats, True)
+        cat, _ = tm.make_categorical(rows, cols, cats, False)
+
+        by_etl_results = {}
+        by_builtin_results = {}
+
+        predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
+        # Use one-hot exclusively
+        parameters = {
+            "tree_method": tree_method, "predictor": predictor, "max_cat_to_onehot": 9999
+        }
+
+        m = xgb.DMatrix(onehot, label, enable_categorical=False)
+        xgb.train(
+            parameters,
+            m,
+            num_boost_round=rounds,
+            evals=[(m, "Train")],
+            evals_result=by_etl_results,
+        )
+
+        m = xgb.DMatrix(cat, label, enable_categorical=True)
+        xgb.train(
+            parameters,
+            m,
+            num_boost_round=rounds,
+            evals=[(m, "Train")],
+            evals_result=by_builtin_results,
+        )
+
+        # There are guidelines on how to specify tolerance based on considering output as
+        # random variables. But in here the tree construction is extremely sensitive to
+        # floating point errors. An 1e-5 error in a histogram bin can lead to an entirely
+        # different tree. So even though the test is quite lenient, hypothesis can still
+        # pick up falsifying examples from time to time.
+        np.testing.assert_allclose(
+            np.array(by_etl_results["Train"]["rmse"]),
+            np.array(by_builtin_results["Train"]["rmse"]),
+            rtol=1e-3,
+        )
+        assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
+
+    @given(strategies.integers(10, 400), strategies.integers(3, 8),
+           strategies.integers(1, 2), strategies.integers(4, 7))
+    @settings(deadline=None)
+    @pytest.mark.skipif(**tm.no_pandas())
+    def test_categorical(self, rows, cols, rounds, cats):
+        self.run_categorical_basic(rows, cols, rounds, cats, "approx")
@@ -1184,9 +1184,13 @@ class TestWithDask:
             for arg in rabit_args:
                 if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
                     port_env = arg.decode('utf-8')
+                if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
+                    uri_env = arg.decode("utf-8")
             port = port_env.split('=')
             env = os.environ.copy()
             env[port[0]] = port[1]
+            uri = uri_env.split("=")
+            env["DMLC_TRACKER_URI"] = uri[1]
             return subprocess.run([str(exe), test], env=env, capture_output=True)

         with LocalCluster(n_workers=4) as cluster:
@@ -1210,11 +1214,13 @@ class TestWithDask:
     @pytest.mark.gtest
     def test_quantile_basic(self) -> None:
         self.run_quantile('DistributedBasic')
+        self.run_quantile('SortedDistributedBasic')

     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.gtest
     def test_quantile(self) -> None:
         self.run_quantile('Distributed')
+        self.run_quantile('SortedDistributed')

     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.gtest
@@ -1252,13 +1258,17 @@ class TestWithDask:
         for i in range(kCols):
             fw[i] *= float(i)
         fw = da.from_array(fw)
-        poly_increasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor)
+        poly_increasing = run_feature_weights(
+            X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor
+        )

         fw = np.ones(shape=(kCols,))
         for i in range(kCols):
             fw[i] *= float(kCols - i)
         fw = da.from_array(fw)
-        poly_decreasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor)
+        poly_decreasing = run_feature_weights(
+            X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor
+        )

         # Approxmated test, this is dependent on the implementation of random
         # number generator in std library.
@@ -1031,10 +1031,10 @@ def test_pandas_input():
                     np.array([0, 1]))


-def run_feature_weights(X, y, fw, model=xgb.XGBRegressor):
+def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor):
     with tempfile.TemporaryDirectory() as tmpdir:
         colsample_bynode = 0.5
-        reg = model(tree_method='hist', colsample_bynode=colsample_bynode)
+        reg = model(tree_method=tree_method, colsample_bynode=colsample_bynode)

         reg.fit(X, y, feature_weights=fw)
         model_path = os.path.join(tmpdir, 'model.json')
@@ -1069,7 +1069,8 @@ def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor):
     return w


-def test_feature_weights():
+@pytest.mark.parametrize("tree_method", ["approx", "hist"])
+def test_feature_weights(tree_method):
     kRows = 512
     kCols = 64
     X = rng.randn(kRows, kCols)

@@ -1078,12 +1079,12 @@
     fw = np.ones(shape=(kCols,))
     for i in range(kCols):
         fw[i] *= float(i)
-    poly_increasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
+    poly_increasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor)

     fw = np.ones(shape=(kCols,))
     for i in range(kCols):
         fw[i] *= float(kCols - i)
-    poly_decreasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
+    poly_decreasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor)

     # Approxmated test, this is dependent on the implementation of random
     # number generator in std library.