Rewrite approx (#7214)

This PR rewrites the approx tree method to reuse code from the hist tree method, for better performance and code sharing.

The rewrite has many benefits:
- Support for both `max_leaves` and `max_depth`.
- Support for `grow_policy`.
- Support for monotonic constraints.
- Support for feature weights.
- Support for easier bin configuration (`max_bin`).
- Support for categorical data.
- Faster training on most datasets, often many times faster.
- Support for prediction cache.
- Significantly better performance for external memory.
- Unifies the code base of approx and hist.
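
As a rough illustration of how the options listed above map onto training parameters, a parameter dictionary for the rewritten `approx` might look like the sketch below. The parameter names are existing XGBoost parameters; the values, and the three-feature layout implied by `monotone_constraints`, are assumptions chosen only for illustration.

```python
# Hypothetical parameter set exercising the options listed above.
# Values are illustrative, not recommendations.
approx_params = {
    "tree_method": "approx",
    "grow_policy": "lossguide",          # growth policy is now honoured by approx
    "max_leaves": 16,                    # can be combined with max_depth
    "max_depth": 6,
    "max_bin": 64,                       # bin count is set directly
    "monotone_constraints": "(1,0,-1)",  # assumes a dataset with 3 features
    "max_cat_to_onehot": 4,              # experimental categorical threshold
}

# With XGBoost 1.6 or later installed, this would typically be passed as:
#   booster = xgboost.train(approx_params, dtrain, num_boost_round=10)
# where `dtrain` is a DMatrix built by the caller (not shown here).
```

Feature weights are not part of this dictionary; they are attached to the training data rather than passed as a booster parameter.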
Jiaming Yuan 2022-01-10 21:15:05 +08:00 committed by GitHub
parent ed95e77752
commit 001503186c
22 changed files with 635 additions and 264 deletions


@@ -57,6 +57,7 @@
 #include "../src/tree/updater_refresh.cc"
 #include "../src/tree/updater_sync.cc"
 #include "../src/tree/updater_histmaker.cc"
+#include "../src/tree/updater_approx.cc"
 #include "../src/tree/constraints.cc"
 // linear


@@ -3,7 +3,8 @@ Getting started with categorical data
 =====================================
 Experimental support for categorical data. After 1.5 XGBoost `gpu_hist` tree method has
-experimental support for one-hot encoding based tree split.
+experimental support for one-hot encoding based tree split, and in 1.6 `approx` supported
+was added.
 In before, users need to run an encoder themselves before passing the data into XGBoost,
 which creates a sparse matrix and potentially increase memory usage. This demo showcases


@@ -154,7 +154,7 @@ Parameters for Tree Booster
 * ``sketch_eps`` [default=0.03]
-- Only used for ``tree_method=approx``.
+- Only used for ``updater=grow_local_histmaker``.
 - This roughly translates into ``O(1 / sketch_eps)`` number of bins.
 Compared to directly select number of bins, this comes with theoretical guarantee with sketch accuracy.
 - Usually user does not have to tune this.
@@ -238,13 +238,27 @@ Parameters for Tree Booster
 list is a group of indices of features that are allowed to interact with each other.
 See :doc:`/tutorials/feature_interaction_constraint` for more information.
-Additional parameters for ``hist`` and ``gpu_hist`` tree method
-================================================================
+Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
+===========================================================================
 * ``single_precision_histogram``, [default= ``false``]
 - Use single precision to build histograms instead of double precision.
+Additional parameters for ``approx`` tree method
+================================================
+* ``max_cat_to_onehot``
+.. versionadded:: 1.6
+.. note:: The support for this parameter is experimental.
+- A threshold for deciding whether XGBoost should use one-hot encoding based split for
+categorical data. When number of categories is lesser than the threshold then one-hot
+encoding is chosen, otherwise the categories will be partitioned into children nodes.
+Only relevant for regression and binary classification with `approx` tree method.
 Additional parameters for Dart Booster (``booster=dart``)
 =========================================================
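
The `max_cat_to_onehot` rule documented in this file can be sketched as a single comparison. This is a minimal illustration that follows the documentation wording ("lesser than the threshold"); the function name is hypothetical and the library's exact boundary behaviour is not confirmed by this diff.

```python
def uses_onehot_split(n_categories: int, max_cat_to_onehot: int) -> bool:
    """Return True when a categorical feature would get one-hot based splits.

    Assumption: a strict "less than" comparison, as the documentation text
    suggests. Otherwise the categories are partitioned into child nodes.
    """
    return n_categories < max_cat_to_onehot


# A feature with 3 categories under the default threshold of 4 set in this
# commit would use one-hot splits; one with 10 categories would not.
small_feature = uses_onehot_split(3, 4)
large_feature = uses_onehot_split(10, 4)
```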


@@ -53,7 +53,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 "max_depth" -> "6",
 "silent" -> "1",
 "objective" -> "reg:squarederror",
-"max_bin" -> 16,
+"max_bin" -> 64,
 "tree_method" -> treeMethod)
 val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)


@@ -267,6 +267,16 @@ __model_doc = f'''
 callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
 save_best=True)]
+max_cat_to_onehot : bool
+.. versionadded:: 1.6.0
+A threshold for deciding whether XGBoost should use one-hot encoding based split
+for categorical data. When number of categories is lesser than the threshold then
+one-hot encoding is chosen, otherwise the categories will be partitioned into
+children nodes. Only relevant for regression and binary classification and
+`approx` tree method.
 kwargs : dict, optional
 Keyword arguments for XGBoost Booster object. Full documentation of parameters
 can be found :doc:`here </parameter>`.
@@ -483,6 +493,7 @@ class XGBModel(XGBModelBase):
 eval_metric: Optional[Union[str, List[str], Callable]] = None,
 early_stopping_rounds: Optional[int] = None,
 callbacks: Optional[List[TrainingCallback]] = None,
+max_cat_to_onehot: Optional[int] = None,
 **kwargs: Any
 ) -> None:
 if not SKLEARN_INSTALLED:
@@ -522,6 +533,7 @@ class XGBModel(XGBModelBase):
 self.eval_metric = eval_metric
 self.early_stopping_rounds = early_stopping_rounds
 self.callbacks = callbacks
+self.max_cat_to_onehot = max_cat_to_onehot
 if kwargs:
 self.kwargs = kwargs
@@ -800,8 +812,8 @@ class XGBModel(XGBModelBase):
 _duplicated("callbacks")
 callbacks = self.callbacks if self.callbacks is not None else callbacks
-# lastly check categorical data support.
-if self.enable_categorical and params.get("tree_method", None) != "gpu_hist":
+tree_method = params.get("tree_method", None)
+if self.enable_categorical and tree_method not in ("gpu_hist", "approx"):
 raise ValueError(
 "Experimental support for categorical data is not implemented for"
 " current tree method yet."
@@ -876,8 +888,7 @@ class XGBModel(XGBModelBase):
 feature_weights :
 Weight for each feature, defines the probability of each feature being
 selected when colsample is being used. All values must be greater than 0,
-otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
-`exact` tree methods.
+otherwise a `ValueError` is thrown.
 callbacks :
 .. deprecated: 1.6.0
@@ -1750,8 +1761,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
 feature_weights :
 Weight for each feature, defines the probability of each feature being
 selected when colsample is being used. All values must be greater than 0,
-otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
-`exact` tree methods.
+otherwise a `ValueError` is thrown.
 callbacks :
 .. deprecated: 1.6.0


@@ -130,8 +130,7 @@ class BlockedSpace2d {
 template <typename Func>
 void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
 const size_t num_blocks_in_space = space.Size();
-nthreads = std::min(nthreads, omp_get_max_threads());
-nthreads = std::max(nthreads, 1);
+CHECK_GE(nthreads, 1);
 dmlc::OMPException exc;
 #pragma omp parallel num_threads(nthreads)
@@ -277,9 +276,10 @@ inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) {
 inline int32_t OmpGetNumThreads(int32_t n_threads) {
 if (n_threads <= 0) {
-n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());
+n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());
 }
 n_threads = std::min(n_threads, OmpGetThreadLimit());
+n_threads = std::max(n_threads, 1);
 return n_threads;
 }
 } // namespace common
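
The clamping order in the new `OmpGetNumThreads` can be sketched without OpenMP as follows. This is a plain-Python restatement under the assumption that the three runtime queries are passed in as ordinary integers; the function and argument names are hypothetical.

```python
def resolve_n_threads(requested: int, n_procs: int, max_threads: int,
                      thread_limit: int) -> int:
    """Mirror the clamping steps of the updated OmpGetNumThreads.

    n_procs, max_threads and thread_limit stand in for the values returned by
    omp_get_num_procs(), omp_get_max_threads() and OmpGetThreadLimit().
    """
    if requested <= 0:
        # No explicit request: use the smaller of processor count and the
        # OpenMP maximum, so omp_set_num_threads() is respected.
        requested = min(n_procs, max_threads)
    # Never exceed the thread limit.
    requested = min(requested, thread_limit)
    # Always return at least one thread, which is what lets ParallelFor2d
    # check `nthreads >= 1` instead of clamping again.
    return max(requested, 1)
```

Because the result is always at least 1, callers such as `ParallelFor2d` can assert on the value rather than silently correcting it.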


@@ -168,7 +168,7 @@ void GBTree::ConfigureUpdaters() {
 // calling this function.
 break;
 case TreeMethod::kApprox:
-tparam_.updater_seq = "grow_histmaker,prune";
+tparam_.updater_seq = "grow_histmaker";
 break;
 case TreeMethod::kExact:
 tparam_.updater_seq = "grow_colmaker,prune";


@@ -38,7 +38,7 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
 enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 };
 int grow_policy;
-uint32_t max_cat_to_onehot{1};
+uint32_t max_cat_to_onehot{4};
 //----- the rest parameters are less important ----
 // minimum amount of hessian(weight) allowed in a child


@@ -973,6 +973,7 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const {
 }
 size_t size = categories.size() - begin;
 categories_sizes.emplace_back(static_cast<Integer::Int>(size));
+CHECK_NE(size, 0);
 }
 }


@@ -35,6 +35,7 @@ DMLC_REGISTRY_LINK_TAG(updater_refresh);
 DMLC_REGISTRY_LINK_TAG(updater_prune);
 DMLC_REGISTRY_LINK_TAG(updater_quantile_hist);
 DMLC_REGISTRY_LINK_TAG(updater_histmaker);
+DMLC_REGISTRY_LINK_TAG(updater_approx);
 DMLC_REGISTRY_LINK_TAG(updater_sync);
 #ifdef XGBOOST_USE_CUDA
 DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);

src/tree/updater_approx.cc (new file, 369 lines)

@@ -0,0 +1,369 @@
/*!
* Copyright 2021 XGBoost contributors
*
* \brief Implementation for the approx tree method.
*/
#include "updater_approx.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "../common/random.h"
#include "../data/gradient_index.h"
#include "constraints.h"
#include "driver.h"
#include "hist/evaluate_splits.h"
#include "hist/histogram.h"
#include "hist/param.h"
#include "param.h"
#include "xgboost/base.h"
#include "xgboost/json.h"
#include "xgboost/tree_updater.h"
namespace xgboost {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_approx);
template <typename GradientSumT>
class GloablApproxBuilder {
protected:
TrainParam param_;
std::shared_ptr<common::ColumnSampler> col_sampler_;
HistEvaluator<GradientSumT, CPUExpandEntry> evaluator_;
HistogramBuilder<GradientSumT, CPUExpandEntry> histogram_builder_;
GenericParameter const *ctx_;
std::vector<ApproxRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache.
RegTree *p_last_tree_{nullptr};
common::Monitor *monitor_;
size_t n_batches_{0};
// Cache for histogram cuts.
common::HistogramCuts feature_values_;
public:
void InitData(DMatrix *p_fmat, common::Span<float> hess) {
monitor_->Start(__func__);
n_batches_ = 0;
int32_t n_total_bins = 0;
partitioner_.clear();
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, param_.max_bin, hess, true})) {
if (n_total_bins == 0) {
n_total_bins = page.cut.TotalBins();
feature_values_ = page.cut;
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(page.Size(), page.base_rowid);
n_batches_++;
}
histogram_builder_.Reset(n_total_bins,
BatchParam{GenericParameter::kCpuId, param_.max_bin, hess},
ctx_->Threads(), n_batches_, rabit::IsDistributed());
monitor_->Stop(__func__);
}
CPUExpandEntry InitRoot(DMatrix *p_fmat, std::vector<GradientPair> const &gpair,
common::Span<float> hess, RegTree *p_tree) {
monitor_->Start(__func__);
CPUExpandEntry best;
best.nid = RegTree::kRoot;
best.depth = 0;
GradStats root_sum;
for (auto const &g : gpair) {
root_sum.Add(g);
}
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double *>(&root_sum), 2);
std::vector<CPUExpandEntry> nodes{best};
size_t i = 0;
auto space = this->ConstructHistSpace(nodes);
for (auto const &page :
p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
{}, gpair);
i++;
}
auto weight = evaluator_.InitRoot(root_sum);
p_tree->Stat(RegTree::kRoot).sum_hess = root_sum.GetHess();
p_tree->Stat(RegTree::kRoot).base_weight = weight;
(*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);
auto const &histograms = histogram_builder_.Histogram();
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &nodes);
monitor_->Stop(__func__);
return nodes.front();
}
void UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) {
monitor_->Start(__func__);
// Caching prediction seems redundant for approx tree method, as sketching takes up
// majority of training time.
CHECK_EQ(out_preds.Size(), data->Info().num_row_);
CHECK(p_last_tree_);
size_t n_nodes = p_last_tree_->GetNodes().size();
auto evaluator = evaluator_.Evaluator();
auto const &tree = *p_last_tree_;
auto const &snode = evaluator_.Stats();
for (auto &part : partitioner_) {
CHECK_EQ(part.Size(), n_nodes);
common::BlockedSpace2d space(
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
common::ParallelFor2d(space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) {
if (tree[nidx].IsLeaf()) {
const auto rowset = part[nidx];
auto const &stats = snode.at(nidx);
auto leaf_value =
evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * param_.learning_rate;
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
out_preds(*it) += leaf_value;
}
}
});
}
monitor_->Stop(__func__);
}
// Construct a work space for building histogram. Eventually we should move this
// function into histogram builder once hist tree method supports external memory.
common::BlockedSpace2d ConstructHistSpace(
std::vector<CPUExpandEntry> const &nodes_to_build) const {
std::vector<size_t> partition_size(nodes_to_build.size(), 0);
for (auto const &partition : partitioner_) {
size_t k = 0;
for (auto node : nodes_to_build) {
auto n_rows_in_node = partition.Partitions()[node.nid].Size();
partition_size[k] = std::max(partition_size[k], n_rows_in_node);
k++;
}
}
common::BlockedSpace2d space{nodes_to_build.size(),
[&](size_t nidx_in_set) { return partition_size[nidx_in_set]; },
256};
return space;
}
void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree,
std::vector<CPUExpandEntry> const &valid_candidates,
std::vector<GradientPair> const &gpair, common::Span<float> hess) {
monitor_->Start(__func__);
std::vector<CPUExpandEntry> nodes_to_build;
std::vector<CPUExpandEntry> nodes_to_sub;
for (auto const &c : valid_candidates) {
auto left_nidx = (*p_tree)[c.nid].LeftChild();
auto right_nidx = (*p_tree)[c.nid].RightChild();
auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess();
auto build_nidx = left_nidx;
auto subtract_nidx = right_nidx;
if (fewer_right) {
std::swap(build_nidx, subtract_nidx);
}
nodes_to_build.push_back(CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}});
nodes_to_sub.push_back(CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}});
}
size_t i = 0;
auto space = this->ConstructHistSpace(nodes_to_build);
for (auto const &page :
p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
nodes_to_build, nodes_to_sub, gpair);
i++;
}
monitor_->Stop(__func__);
}
public:
explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, GenericParameter const *ctx,
std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task,
common::Monitor *monitor)
: param_{std::move(param)},
col_sampler_{std::move(column_sampler)},
evaluator_{param_, info, ctx->Threads(), col_sampler_, task},
ctx_{ctx},
monitor_{monitor} {}
void UpdateTree(RegTree *p_tree, std::vector<GradientPair> const &gpair, common::Span<float> hess,
DMatrix *p_fmat) {
p_last_tree_ = p_tree;
this->InitData(p_fmat, hess);
Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
auto &tree = *p_tree;
driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
bst_node_t num_leaves = 1;
auto expand_set = driver.Pop();
while (!expand_set.empty()) {
// candidates that can be further splited.
std::vector<CPUExpandEntry> valid_candidates;
// candidates that can be applied.
std::vector<CPUExpandEntry> applied;
for (auto const &candidate : expand_set) {
if (!candidate.IsValid(param_, num_leaves)) {
continue;
}
evaluator_.ApplyTreeSplit(candidate, p_tree);
applied.push_back(candidate);
num_leaves++;
int left_child_nidx = tree[candidate.nid].LeftChild();
if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx), num_leaves)) {
valid_candidates.emplace_back(candidate);
}
}
monitor_->Start("UpdatePosition");
size_t i = 0;
for (auto const &page :
p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree);
i++;
}
monitor_->Stop("UpdatePosition");
std::vector<CPUExpandEntry> best_splits;
if (!valid_candidates.empty()) {
this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair, hess);
for (auto const &candidate : valid_candidates) {
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx), {}};
CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx), {}};
best_splits.push_back(l_best);
best_splits.push_back(r_best);
}
auto const &histograms = histogram_builder_.Histogram();
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
monitor_->Start("EvaluateSplits");
evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &best_splits);
monitor_->Stop("EvaluateSplits");
}
driver.Push(best_splits.begin(), best_splits.end());
expand_set = driver.Pop();
}
}
};
/**
* \brief Implementation for the approx tree method. It constructs quantile for every
* iteration.
*/
class GlobalApproxUpdater : public TreeUpdater {
TrainParam param_;
common::Monitor monitor_;
CPUHistMakerTrainParam hist_param_;
// specializations for different histogram precision.
std::unique_ptr<GloablApproxBuilder<float>> f32_impl_;
std::unique_ptr<GloablApproxBuilder<double>> f64_impl_;
// pointer to the last DMatrix, used for update prediction cache.
DMatrix *cached_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_ =
std::make_shared<common::ColumnSampler>();
ObjInfo task_;
public:
explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); }
void Configure(const Args &args) override {
param_.UpdateAllowUnknown(args);
hist_param_.UpdateAllowUnknown(args);
}
void LoadConfig(Json const &in) override {
auto const &config = get<Object const>(in);
FromJson(config.at("train_param"), &this->param_);
FromJson(config.at("hist_param"), &this->hist_param_);
}
void SaveConfig(Json *p_out) const override {
auto &out = *p_out;
out["train_param"] = ToJson(param_);
out["hist_param"] = ToJson(hist_param_);
}
void InitData(TrainParam const &param, HostDeviceVector<GradientPair> *gpair,
std::vector<GradientPair> *sampled) {
auto const &h_gpair = gpair->HostVector();
sampled->resize(h_gpair.size());
std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin());
auto &rnd = common::GlobalRandom();
if (param.subsample != 1.0) {
CHECK(param.sampling_method != TrainParam::kGradientBased)
<< "Gradient based sampling is not supported for approx tree method.";
std::bernoulli_distribution coin_flip(param.subsample);
std::transform(sampled->begin(), sampled->end(), sampled->begin(), [&](GradientPair &g) {
if (coin_flip(rnd)) {
return g;
} else {
return GradientPair{};
}
});
}
}
char const *Name() const override { return "grow_histmaker"; }
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *m,
const std::vector<RegTree *> &trees) override {
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
if (hist_param_.single_precision_histogram) {
f32_impl_ = std::make_unique<GloablApproxBuilder<float>>(param_, m->Info(), tparam_,
column_sampler_, task_, &monitor_);
} else {
f64_impl_ = std::make_unique<GloablApproxBuilder<double>>(param_, m->Info(), tparam_,
column_sampler_, task_, &monitor_);
}
std::vector<GradientPair> h_gpair;
InitData(param_, gpair, &h_gpair);
// Obtain the hessian values for weighted sketching
std::vector<float> hess(h_gpair.size());
std::transform(h_gpair.begin(), h_gpair.end(), hess.begin(),
[](auto g) { return g.GetHess(); });
cached_ = m;
for (auto p_tree : trees) {
if (hist_param_.single_precision_histogram) {
this->f32_impl_->UpdateTree(p_tree, h_gpair, hess, m);
} else {
this->f64_impl_->UpdateTree(p_tree, h_gpair, hess, m);
}
}
param_.learning_rate = lr;
}
bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
if (data != cached_ || (!this->f32_impl_ && !this->f64_impl_)) {
return false;
}
if (hist_param_.single_precision_histogram) {
this->f32_impl_->UpdatePredictionCache(data, out_preds);
} else {
this->f64_impl_->UpdatePredictionCache(data, out_preds);
}
return true;
}
};
DMLC_REGISTRY_FILE_TAG(grow_histmaker);
XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
.describe(
"Tree constructor that uses approximate histogram construction "
"for each node.")
.set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); });
} // namespace tree
} // namespace xgboost
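
`BuildHistogram` in the new updater picks, for each split, the child with the smaller Hessian sum to build directly and queues its sibling for subtraction. The idea can be sketched in plain Python as below. This is a simplified illustration with a single gradient value per row and hypothetical helper names; the real subtraction happens inside the shared histogram builder, which is not shown in this diff.

```python
from typing import List, Sequence, Tuple


def build_hist(rows: Sequence[int], bins: Sequence[int],
               grad: Sequence[float], n_bins: int) -> List[float]:
    """Accumulate per-bin gradient sums for the given rows."""
    hist = [0.0] * n_bins
    for r in rows:
        hist[bins[r]] += grad[r]
    return hist


def split_hists(parent_hist: Sequence[float], left_rows: Sequence[int],
                right_rows: Sequence[int], bins: Sequence[int],
                grad: Sequence[float],
                hess: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Return (left, right) histograms, building only the lighter child."""
    left_hess = sum(hess[r] for r in left_rows)
    right_hess = sum(hess[r] for r in right_rows)
    # Same comparison as `fewer_right` in the updater: build the child with
    # the smaller Hessian sum, derive the other one from the parent.
    fewer_right = right_hess < left_hess
    build_rows = right_rows if fewer_right else left_rows
    built = build_hist(build_rows, bins, grad, len(parent_hist))
    sibling = [p - b for p, b in zip(parent_hist, built)]
    return (sibling, built) if fewer_right else (built, sibling)
```

Only one child per split is scanned row by row; the other costs one pass over the bins, which is why the lighter child is the one chosen for building.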


@@ -641,126 +641,10 @@ class CQHistMaker: public HistMaker {
 std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs_;
 };
-// global proposal
-class GlobalProposalHistMaker: public CQHistMaker {
-public:
-char const* Name() const override {
-return "grow_histmaker";
-}
-protected:
-void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
-DMatrix *p_fmat,
-const std::vector<bst_feature_t> &fset,
-const RegTree &tree) override {
-if (this->qexpand_.size() == 1) {
-cached_rptr_.clear();
-cached_cut_.clear();
-}
-if (cached_rptr_.size() == 0) {
-CHECK_EQ(this->qexpand_.size(), 1U);
-CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree);
-cached_rptr_ = this->wspace_.rptr;
-cached_cut_ = this->wspace_.cut;
-} else {
-this->wspace_.cut.clear();
-this->wspace_.rptr.clear();
-this->wspace_.rptr.push_back(0);
-for (size_t i = 0; i < this->qexpand_.size(); ++i) {
-for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) {
-this->wspace_.rptr.push_back(
-this->wspace_.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]);
-}
-this->wspace_.cut.insert(this->wspace_.cut.end(), cached_cut_.begin(), cached_cut_.end());
-}
-CHECK_EQ(this->wspace_.rptr.size(),
-(fset.size() + 1) * this->qexpand_.size() + 1);
-CHECK_EQ(this->wspace_.rptr.back(), this->wspace_.cut.size());
-}
-}
-// code to create histogram
-void CreateHist(const std::vector<GradientPair> &gpair,
-DMatrix *p_fmat,
-const std::vector<bst_feature_t> &fset,
-const RegTree &tree) override {
-const MetaInfo &info = p_fmat->Info();
-// fill in reverse map
-this->feat2workindex_.resize(tree.param.num_feature);
-this->work_set_ = fset;
-std::fill(this->feat2workindex_.begin(), this->feat2workindex_.end(), -1);
-for (size_t i = 0; i < fset.size(); ++i) {
-this->feat2workindex_[fset[i]] = static_cast<int>(i);
-}
-// start to work
-this->wspace_.Configure(1);
-// to gain speedup in recovery
-{
-this->thread_hist_.resize(omp_get_max_threads());
-// TWOPASS: use the real set + split set in the column iteration.
-this->SetDefaultPostion(p_fmat, tree);
-this->work_set_.insert(this->work_set_.end(), this->fsplit_set_.begin(),
-this->fsplit_set_.end());
-XGBOOST_PARALLEL_SORT(this->work_set_.begin(), this->work_set_.end(),
-std::less<>{});
-this->work_set_.resize(
-std::unique(this->work_set_.begin(), this->work_set_.end()) - this->work_set_.begin());
-// start accumulating statistics
-for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
-// TWOPASS: use the real set + split set in the column iteration.
-this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);
-auto page = batch.GetView();
-// start enumeration
-const auto nsize = static_cast<bst_omp_uint>(this->work_set_.size());
-dmlc::OMPException exc;
-#pragma omp parallel for schedule(dynamic, 1)
-for (bst_omp_uint i = 0; i < nsize; ++i) {
-exc.Run([&]() {
-int fid = this->work_set_[i];
-int offset = this->feat2workindex_[fid];
-if (offset >= 0) {
-this->UpdateHistCol(gpair, page[fid], info, tree,
-fset, offset,
-&this->thread_hist_[omp_get_thread_num()]);
-}
-});
-}
-exc.Rethrow();
-}
-// update node statistics.
-this->GetNodeStats(gpair, *p_fmat, tree,
-&(this->thread_stats_), &(this->node_stats_));
-for (const int nid : this->qexpand_) {
-const int wid = this->node2workindex_[nid];
-this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)]
-.data[0] = this->node_stats_[nid];
-}
-}
-this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
-this->wspace_.hset[0].data.size());
-}
-// cached unit pointer
-std::vector<unsigned> cached_rptr_;
-// cached cut value.
-std::vector<bst_float> cached_cut_;
-};
 XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
 .describe("Tree constructor that uses approximate histogram construction.")
 .set_body([](ObjInfo) {
 return new CQHistMaker();
 });
-// The updater for approx tree method.
-XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
-.describe("Tree constructor that uses approximate global of histogram construction.")
-.set_body([](ObjInfo) {
-return new GlobalProposalHistMaker();
-});
 } // namespace tree
 } // namespace xgboost


@@ -35,7 +35,7 @@ TEST(GBTree, SelectTreeMethod) {
 gbtree.Configure(args);
 auto const& tparam = gbtree.GetTrainParam();
 gbtree.Configure({{"tree_method", "approx"}});
-ASSERT_EQ(tparam.updater_seq, "grow_histmaker,prune");
+ASSERT_EQ(tparam.updater_seq, "grow_histmaker");
 gbtree.Configure({{"tree_method", "exact"}});
 ASSERT_EQ(tparam.updater_seq, "grow_colmaker,prune");
 gbtree.Configure({{"tree_method", "hist"}});


@@ -72,5 +72,58 @@ TEST(Approx, Partitioner) {
 }
 }
 }
+TEST(Approx, PredictionCache) {
+size_t n_samples = 2048, n_features = 13;
+auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
+{
+omp_set_num_threads(1);
+GenericParameter ctx;
+ctx.InitAllowUnknown(Args{{"nthread", "8"}});
+std::unique_ptr<TreeUpdater> approx{
+TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
+RegTree tree;
+std::vector<RegTree *> trees{&tree};
+auto gpair = GenerateRandomGradients(n_samples);
+approx->Configure(Args{{"max_bin", "64"}});
+approx->Update(&gpair, Xy.get(), trees);
+HostDeviceVector<float> out_prediction_cached;
+out_prediction_cached.Resize(n_samples);
+auto cache = linalg::VectorView<float>{
+out_prediction_cached.HostSpan(), {out_prediction_cached.Size()}, GenericParameter::kCpuId};
+ASSERT_TRUE(approx->UpdatePredictionCache(Xy.get(), cache));
+}
+std::unique_ptr<Learner> learner{Learner::Create({Xy})};
+learner->SetParam("tree_method", "approx");
+learner->SetParam("nthread", "0");
+learner->Configure();
+for (size_t i = 0; i < 8; ++i) {
+learner->UpdateOneIter(i, Xy);
+}
+HostDeviceVector<float> out_prediction_cached;
+learner->Predict(Xy, false, &out_prediction_cached, 0, 0);
+Json model{Object()};
+learner->SaveModel(&model);
+HostDeviceVector<float> out_prediction;
+{
+std::unique_ptr<Learner> learner{Learner::Create({Xy})};
+learner->LoadModel(model);
+learner->Predict(Xy, false, &out_prediction, 0, 0);
+}
+auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
+auto const h_predt = out_prediction.ConstHostSpan();
+ASSERT_EQ(h_predt.size(), h_predt_cached.size());
+for (size_t i = 0; i < h_predt.size(); ++i) {
+ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
+}
+}
 } // namespace tree
 } // namespace xgboost


@@ -315,57 +315,6 @@ TEST(GpuHist, TestHistogramIndex) {
 TestHistogramIndexImpl();
 }
-// gamma is an alias of min_split_loss
-int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector<GradientPair>* gpair) {
-Args args {
-{"max_depth", "1"},
-{"max_leaves", "0"},
-// Disable all other parameters.
-{"colsample_bynode", "1"},
-{"colsample_bylevel", "1"},
-{"colsample_bytree", "1"},
-{"min_child_weight", "0.01"},
-{"reg_alpha", "0"},
-{"reg_lambda", "0"},
-{"max_delta_step", "0"},
-// test gamma
-{"gamma", std::to_string(gamma)}
-};
-tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
-GenericParameter generic_param(CreateEmptyGenericParam(0));
-hist_maker.Configure(args, &generic_param);
-RegTree tree;
-hist_maker.Update(gpair, dmat, {&tree});
-auto n_nodes = tree.NumExtraNodes();
-return n_nodes;
-}
-TEST(GpuHist, MinSplitLoss) {
-constexpr size_t kRows = 32;
-constexpr size_t kCols = 16;
-constexpr float kSparsity = 0.6;
-auto dmat = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix();
-auto gpair = GenerateRandomGradients(kRows);
-{
-int32_t n_nodes = TestMinSplitLoss(dmat.get(), 0.01, &gpair);
-// This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured
-// when writing this test, and only used for testing larger gamma (below) does prevent
-// building tree.
-ASSERT_EQ(n_nodes, 2);
-}
-{
-int32_t n_nodes = TestMinSplitLoss(dmat.get(), 100.0, &gpair);
-// No new nodes with gamma == 100.
-ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
-}
-}
 void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
 size_t gpu_page_size, RegTree* tree,
 HostDeviceVector<bst_float>* preds, float subsample = 1.0f,


@ -61,7 +61,7 @@ class TestGrowPolicy : public ::testing::Test {
} }
}; };
TEST_F(TestGrowPolicy, DISABLED_Approx) { TEST_F(TestGrowPolicy, Approx) {
this->TestTreeGrowPolicy("approx", "depthwise"); this->TestTreeGrowPolicy("approx", "depthwise");
this->TestTreeGrowPolicy("approx", "lossguide"); this->TestTreeGrowPolicy("approx", "lossguide");
} }


@ -114,4 +114,70 @@ TEST_F(UpdaterEtaTest, Approx) { this->RunTest("grow_histmaker"); }
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
TEST_F(UpdaterEtaTest, GpuHist) { this->RunTest("grow_gpu_hist"); } TEST_F(UpdaterEtaTest, GpuHist) { this->RunTest("grow_gpu_hist"); }
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
class TestMinSplitLoss : public ::testing::Test {
std::shared_ptr<DMatrix> dmat_;
HostDeviceVector<GradientPair> gpair_;
void SetUp() override {
constexpr size_t kRows = 32;
constexpr size_t kCols = 16;
constexpr float kSparsity = 0.6;
dmat_ = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix();
gpair_ = GenerateRandomGradients(kRows);
}
int32_t Update(std::string updater, float gamma) {
Args args{{"max_depth", "1"},
{"max_leaves", "0"},
// Disable all other parameters.
{"colsample_bynode", "1"},
{"colsample_bylevel", "1"},
{"colsample_bytree", "1"},
{"min_child_weight", "0.01"},
{"reg_alpha", "0"},
{"reg_lambda", "0"},
{"max_delta_step", "0"},
// test gamma
{"gamma", std::to_string(gamma)}};
GenericParameter generic_param(CreateEmptyGenericParam(0));
auto up = std::unique_ptr<TreeUpdater>{
TreeUpdater::Create(updater, &generic_param, ObjInfo{ObjInfo::kRegression})};
up->Configure(args);
RegTree tree;
up->Update(&gpair_, dmat_.get(), {&tree});
auto n_nodes = tree.NumExtraNodes();
return n_nodes;
}
public:
void RunTest(std::string updater) {
{
int32_t n_nodes = Update(updater, 0.01);
// This is not strictly verified: the number `2` is whatever GPU_Hist returned
// when this test was written, and is only used to check that a larger gamma (below)
// does prevent building the tree.
ASSERT_EQ(n_nodes, 2);
}
{
int32_t n_nodes = Update(updater, 100.0);
// No new nodes with gamma == 100.
ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
}
}
};
/* Exact tree method requires a pruner as an additional updater, so not tested here. */
TEST_F(TestMinSplitLoss, Approx) { this->RunTest("grow_histmaker"); }
TEST_F(TestMinSplitLoss, Hist) { this->RunTest("grow_quantile_histmaker"); }
#if defined(XGBOOST_USE_CUDA)
TEST_F(TestMinSplitLoss, GpuHist) { this->RunTest("grow_gpu_hist"); }
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost } // namespace xgboost
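
The `TestMinSplitLoss` fixture above checks that a small `gamma` still allows a split while `gamma == 100` blocks it. As a rough illustration of why, here is a minimal Python sketch of the textbook split-gain formula with the `gamma` penalty. The function names `split_gain` and `should_split` are hypothetical, and the updaters in this commit may scale or organise the computation differently.

```python
def split_gain(g_left, h_left, g_right, h_right, reg_lambda=0.0, gamma=0.0):
    """Textbook gain of splitting a node into two children, minus the gamma penalty.

    g_* are sums of gradients and h_* are sums of hessians in each child.
    Assumes h + reg_lambda > 0 for both children (illustration only).
    """

    def score(g, h):
        # Structure score of a single leaf.
        return g * g / (h + reg_lambda)

    parent = score(g_left + g_right, h_left + h_right)
    return 0.5 * (score(g_left, h_left) + score(g_right, h_right) - parent) - gamma


def should_split(g_left, h_left, g_right, h_right, reg_lambda=0.0, gamma=0.0):
    # A split is kept only when the penalised gain is positive.
    return split_gain(g_left, h_left, g_right, h_right, reg_lambda, gamma) > 0.0


if __name__ == "__main__":
    # Same candidate split, two gamma values, mirroring the two cases in the test.
    print(should_split(-4.0, 2.0, 4.0, 2.0, gamma=0.01))
    print(should_split(-4.0, 2.0, 4.0, 2.0, gamma=100.0))
```

With `reg_lambda` set to zero, as in the test arguments, the only thing separating the two cases is the size of `gamma` relative to the unpenalised gain.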


@ -7,6 +7,8 @@ from hypothesis import given, strategies, assume, settings, note
sys.path.append("tests/python") sys.path.append("tests/python")
import testing as tm import testing as tm
import test_updaters as test_up
parameter_strategy = strategies.fixed_dictionaries({ parameter_strategy = strategies.fixed_dictionaries({
'max_depth': strategies.integers(0, 11), 'max_depth': strategies.integers(0, 11),
@ -32,6 +34,8 @@ def train_result(param, dmat, num_rounds):
class TestGPUUpdaters: class TestGPUUpdaters:
cputest = test_up.TestTreeMethod()
@given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy) @given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy)
@settings(deadline=None) @settings(deadline=None)
def test_gpu_hist(self, param, num_rounds, dataset): def test_gpu_hist(self, param, num_rounds, dataset):
@ -41,51 +45,12 @@ class TestGPUUpdaters:
note(result) note(result)
assert tm.non_increasing(result["train"][dataset.metric]) assert tm.non_increasing(result["train"][dataset.metric])
def run_categorical_basic(self, rows, cols, rounds, cats):
onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False)
by_etl_results = {}
by_builtin_results = {}
parameters = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
m = xgb.DMatrix(onehot, label, enable_categorical=False)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_etl_results,
)
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_builtin_results,
)
# There are guidelines on how to specify tolerance based on considering output as
# random variables. But here the tree construction is extremely sensitive to
# floating point errors. A 1e-5 error in a histogram bin can lead to an entirely
# different tree. So even though the test is quite lenient, hypothesis can still
# pick up falsifying examples from time to time.
np.testing.assert_allclose(
np.array(by_etl_results["Train"]["rmse"]),
np.array(by_builtin_results["Train"]["rmse"]),
rtol=1e-3,
)
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
@given(strategies.integers(10, 400), strategies.integers(3, 8), @given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7)) strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None) @settings(deadline=None)
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
def test_categorical(self, rows, cols, rounds, cats): def test_categorical(self, rows, cols, rounds, cats):
self.run_categorical_basic(rows, cols, rounds, cats) self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
def test_categorical_32_cat(self): def test_categorical_32_cat(self):
'''32 hits the bound of integer bitset, so special test''' '''32 hits the bound of integer bitset, so special test'''
@ -93,7 +58,7 @@ class TestGPUUpdaters:
cols = 10 cols = 10
cats = 32 cats = 32
rounds = 4 rounds = 4
self.run_categorical_basic(rows, cols, rounds, cats) self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
def test_invalid_categorical(self): def test_invalid_categorical(self):
import cupy as cp import cupy as cp


@ -63,7 +63,6 @@ training_dset = xgb.DMatrix(x, label=y)
class TestMonotoneConstraints: class TestMonotoneConstraints:
def test_monotone_constraints_for_exact_tree_method(self): def test_monotone_constraints_for_exact_tree_method(self):
# first check monotonicity for the 'exact' tree method # first check monotonicity for the 'exact' tree method
@ -76,32 +75,23 @@ class TestMonotoneConstraints:
) )
assert is_correctly_constrained(constrained_exact_method) assert is_correctly_constrained(constrained_exact_method)
def test_monotone_constraints_for_depthwise_hist_tree_method(self): @pytest.mark.parametrize(
"tree_method,policy",
# next check monotonicity for the 'hist' tree method [
params_for_constrained_hist_method = { ("hist", "depthwise"),
'tree_method': 'hist', 'verbosity': 1, ("approx", "depthwise"),
'monotone_constraints': '(1, -1)' ("hist", "lossguide"),
} ("approx", "lossguide"),
constrained_hist_method = xgb.train( ],
params_for_constrained_hist_method, training_dset
) )
def test_monotone_constraints(self, tree_method: str, policy: str) -> None:
assert is_correctly_constrained(constrained_hist_method) params_for_constrained = {
"tree_method": tree_method,
def test_monotone_constraints_for_lossguide_hist_tree_method(self): "grow_policy": policy,
"monotone_constraints": "(1, -1)",
# next check monotonicity for the 'hist' tree method
params_for_constrained_hist_method = {
'tree_method': 'hist', 'verbosity': 1,
'grow_policy': 'lossguide',
'monotone_constraints': '(1, -1)'
} }
constrained_hist_method = xgb.train( constrained = xgb.train(params_for_constrained, training_dset)
params_for_constrained_hist_method, training_dset assert is_correctly_constrained(constrained)
)
assert is_correctly_constrained(constrained_hist_method)
@pytest.mark.parametrize('format', [dict, list]) @pytest.mark.parametrize('format', [dict, list])
def test_monotone_constraints_feature_names(self, format): def test_monotone_constraints_feature_names(self, format):


@ -45,14 +45,20 @@ class TestTreeMethod:
result = train_result(param, dataset.get_dmat(), num_rounds) result = train_result(param, dataset.get_dmat(), num_rounds)
assert tm.non_increasing(result['train'][dataset.metric]) assert tm.non_increasing(result['train'][dataset.metric])
@given(exact_parameter_strategy, strategies.integers(1, 20), @given(
tm.dataset_strategy) exact_parameter_strategy,
hist_parameter_strategy,
strategies.integers(1, 20),
tm.dataset_strategy,
)
@settings(deadline=None) @settings(deadline=None)
def test_approx(self, param, num_rounds, dataset): def test_approx(self, param, hist_param, num_rounds, dataset):
param['tree_method'] = 'approx' param["tree_method"] = "approx"
param = dataset.set_params(param) param = dataset.set_params(param)
param.update(hist_param)
result = train_result(param, dataset.get_dmat(), num_rounds) result = train_result(param, dataset.get_dmat(), num_rounds)
assert tm.non_increasing(result['train'][dataset.metric], 1e-3) note(result)
assert tm.non_increasing(result["train"][dataset.metric])
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
def test_pruner(self): def test_pruner(self):
@ -126,3 +132,53 @@ class TestTreeMethod:
y = [1000000., 0., 0., 500000.] y = [1000000., 0., 0., 500000.]
w = [0, 0, 1, 0] w = [0, 0, 1, 0]
model.fit(X, y, sample_weight=w) model.fit(X, y, sample_weight=w)
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False)
by_etl_results = {}
by_builtin_results = {}
predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
# Use one-hot exclusively
parameters = {
"tree_method": tree_method, "predictor": predictor, "max_cat_to_onehot": 9999
}
m = xgb.DMatrix(onehot, label, enable_categorical=False)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_etl_results,
)
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_builtin_results,
)
# There are guidelines on how to specify tolerance based on considering output as
# random variables. But here the tree construction is extremely sensitive to
# floating point errors. A 1e-5 error in a histogram bin can lead to an entirely
# different tree. So even though the test is quite lenient, hypothesis can still
# pick up falsifying examples from time to time.
np.testing.assert_allclose(
np.array(by_etl_results["Train"]["rmse"]),
np.array(by_builtin_results["Train"]["rmse"]),
rtol=1e-3,
)
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
@given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical(self, rows, cols, rounds, cats):
self.run_categorical_basic(rows, cols, rounds, cats, "approx")
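
Several tests in this file assert `tm.non_increasing(...)` on the recorded training metric, and this diff drops the explicit `1e-3` tolerance that `test_approx` used to pass. The sketch below shows what a check of that kind might look like. It is an assumption for illustration, not the helper from `tests/python/testing.py`, and the default tolerance of `1e-4` is likewise assumed.

```python
def non_increasing(values, tolerance=1e-4):
    """Return True if no element exceeds its predecessor by `tolerance` or more.

    Hypothetical stand-in for a helper such as `tm.non_increasing`.
    """
    return all(
        (later - earlier) < tolerance
        for earlier, later in zip(values, values[1:])
    )


if __name__ == "__main__":
    rmse_history = [0.52, 0.41, 0.41, 0.38]  # made-up metric values
    print(non_increasing(rmse_history))
    # A small upward blip passes only when the tolerance is loosened.
    print(non_increasing([1.0, 1.0005]))
    print(non_increasing([1.0, 1.0005], 1e-3))
```

A tighter tolerance makes the test stricter about transient increases in the training metric between boosting rounds.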


@ -1184,9 +1184,13 @@ class TestWithDask:
for arg in rabit_args: for arg in rabit_args:
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'): if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
port_env = arg.decode('utf-8') port_env = arg.decode('utf-8')
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
uri_env = arg.decode("utf-8")
port = port_env.split('=') port = port_env.split('=')
env = os.environ.copy() env = os.environ.copy()
env[port[0]] = port[1] env[port[0]] = port[1]
uri = uri_env.split("=")
env["DMLC_TRACKER_URI"] = uri[1]
return subprocess.run([str(exe), test], env=env, capture_output=True) return subprocess.run([str(exe), test], env=env, capture_output=True)
with LocalCluster(n_workers=4) as cluster: with LocalCluster(n_workers=4) as cluster:
@ -1210,11 +1214,13 @@ class TestWithDask:
@pytest.mark.gtest @pytest.mark.gtest
def test_quantile_basic(self) -> None: def test_quantile_basic(self) -> None:
self.run_quantile('DistributedBasic') self.run_quantile('DistributedBasic')
self.run_quantile('SortedDistributedBasic')
@pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest @pytest.mark.gtest
def test_quantile(self) -> None: def test_quantile(self) -> None:
self.run_quantile('Distributed') self.run_quantile('Distributed')
self.run_quantile('SortedDistributed')
@pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest @pytest.mark.gtest
@ -1252,13 +1258,17 @@ class TestWithDask:
for i in range(kCols): for i in range(kCols):
fw[i] *= float(i) fw[i] *= float(i)
fw = da.from_array(fw) fw = da.from_array(fw)
poly_increasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor) poly_increasing = run_feature_weights(
X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor
)
fw = np.ones(shape=(kCols,)) fw = np.ones(shape=(kCols,))
for i in range(kCols): for i in range(kCols):
fw[i] *= float(kCols - i) fw[i] *= float(kCols - i)
fw = da.from_array(fw) fw = da.from_array(fw)
poly_decreasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor) poly_decreasing = run_feature_weights(
X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor
)
# Approximate test, this is dependent on the implementation of random # Approximate test, this is dependent on the implementation of random
# number generator in std library. # number generator in std library.


@ -1031,10 +1031,10 @@ def test_pandas_input():
np.array([0, 1])) np.array([0, 1]))
def run_feature_weights(X, y, fw, model=xgb.XGBRegressor): def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor):
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
colsample_bynode = 0.5 colsample_bynode = 0.5
reg = model(tree_method='hist', colsample_bynode=colsample_bynode) reg = model(tree_method=tree_method, colsample_bynode=colsample_bynode)
reg.fit(X, y, feature_weights=fw) reg.fit(X, y, feature_weights=fw)
model_path = os.path.join(tmpdir, 'model.json') model_path = os.path.join(tmpdir, 'model.json')
@ -1069,7 +1069,8 @@ def run_feature_weights(X, y, fw, model=xgb.XGBRegressor):
return w return w
def test_feature_weights(): @pytest.mark.parametrize("tree_method", ["approx", "hist"])
def test_feature_weights(tree_method):
kRows = 512 kRows = 512
kCols = 64 kCols = 64
X = rng.randn(kRows, kCols) X = rng.randn(kRows, kCols)
@ -1078,12 +1079,12 @@ def test_feature_weights():
fw = np.ones(shape=(kCols,)) fw = np.ones(shape=(kCols,))
for i in range(kCols): for i in range(kCols):
fw[i] *= float(i) fw[i] *= float(i)
poly_increasing = run_feature_weights(X, y, fw, xgb.XGBRegressor) poly_increasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor)
fw = np.ones(shape=(kCols,)) fw = np.ones(shape=(kCols,))
for i in range(kCols): for i in range(kCols):
fw[i] *= float(kCols - i) fw[i] *= float(kCols - i)
poly_decreasing = run_feature_weights(X, y, fw, xgb.XGBRegressor) poly_decreasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor)
# Approximate test, this is dependent on the implementation of random # Approximate test, this is dependent on the implementation of random
# number generator in std library. # number generator in std library.