Pass obj info by reference instead of by value. (#8889)

- Pass obj info into tree updater as const pointer.

This way the learner model param no longer has to be initialized before gbm is configured,
which breaks up the dependency between the configurations.
This commit is contained in:
Jiaming Yuan 2023-03-11 01:38:28 +08:00 committed by GitHub
parent 54e001bbf4
commit 6deaec8027
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 125 additions and 112 deletions

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2014-2022 by XGBoost Contributors * Copyright 2014-2023 by XGBoost Contributors
* \file tree_updater.h * \file tree_updater.h
* \brief General primitive for tree learning, * \brief General primitive for tree learning,
* Updating a collection of trees given the information. * Updating a collection of trees given the information.
@ -9,19 +9,17 @@
#define XGBOOST_TREE_UPDATER_H_ #define XGBOOST_TREE_UPDATER_H_
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <xgboost/base.h> #include <xgboost/base.h> // for Args, GradientPair
#include <xgboost/context.h> #include <xgboost/data.h> // DMatrix
#include <xgboost/data.h> #include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/host_device_vector.h> #include <xgboost/linalg.h> // for VectorView
#include <xgboost/linalg.h> #include <xgboost/model.h> // for Configurable
#include <xgboost/model.h> #include <xgboost/span.h> // for Span
#include <xgboost/task.h> #include <xgboost/tree_model.h> // for RegTree
#include <xgboost/tree_model.h>
#include <functional> #include <functional> // for function
#include <string> #include <string> // for string
#include <utility> #include <vector> // for vector
#include <vector>
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -30,8 +28,9 @@ struct TrainParam;
class Json; class Json;
struct Context; struct Context;
struct ObjInfo;
/*! /**
* \brief interface of tree update module, that performs update of a tree. * \brief interface of tree update module, that performs update of a tree.
*/ */
class TreeUpdater : public Configurable { class TreeUpdater : public Configurable {
@ -53,12 +52,12 @@ class TreeUpdater : public Configurable {
* used for modifying existing trees (like `prune`). Return true if it can modify * used for modifying existing trees (like `prune`). Return true if it can modify
* existing trees. * existing trees.
*/ */
virtual bool CanModifyTree() const { return false; } [[nodiscard]] virtual bool CanModifyTree() const { return false; }
/*! /*!
* \brief Whether the out_position in `Update` is valid. This determines whether adaptive * \brief Whether the out_position in `Update` is valid. This determines whether adaptive
* tree can be used. * tree can be used.
*/ */
virtual bool HasNodePosition() const { return false; } [[nodiscard]] virtual bool HasNodePosition() const { return false; }
/** /**
* \brief perform update to the tree models * \brief perform update to the tree models
* *
@ -91,14 +90,15 @@ class TreeUpdater : public Configurable {
return false; return false;
} }
virtual char const* Name() const = 0; [[nodiscard]] virtual char const* Name() const = 0;
/*! /**
* \brief Create a tree updater given name * \brief Create a tree updater given name
* \param name Name of the tree updater. * \param name Name of the tree updater.
* \param ctx A global runtime parameter * \param ctx A global runtime parameter
* \param task Information about the objective.
*/ */
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo task); static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task);
}; };
/*! /*!
@ -106,7 +106,7 @@ class TreeUpdater : public Configurable {
*/ */
struct TreeUpdaterReg struct TreeUpdaterReg
: public dmlc::FunctionRegEntryBase< : public dmlc::FunctionRegEntryBase<
TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo task)>> {}; TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo const* task)>> {};
/*! /*!
* \brief Macro to register tree updater. * \brief Macro to register tree updater.

View File

@ -340,7 +340,7 @@ void GBTree::InitUpdater(Args const& cfg) {
// create new updaters // create new updaters
for (const std::string& pstr : ups) { for (const std::string& pstr : ups) {
std::unique_ptr<TreeUpdater> up( std::unique_ptr<TreeUpdater> up(
TreeUpdater::Create(pstr.c_str(), ctx_, model_.learner_model_param->task)); TreeUpdater::Create(pstr.c_str(), ctx_, &model_.learner_model_param->task));
up->Configure(cfg); up->Configure(cfg);
updaters_.push_back(std::move(up)); updaters_.push_back(std::move(up));
} }
@ -448,7 +448,7 @@ void GBTree::LoadConfig(Json const& in) {
LOG(WARNING) << "Changing updater from `grow_gpu_hist` to `grow_quantile_histmaker`."; LOG(WARNING) << "Changing updater from `grow_gpu_hist` to `grow_quantile_histmaker`.";
} }
std::unique_ptr<TreeUpdater> up{ std::unique_ptr<TreeUpdater> up{
TreeUpdater::Create(name, ctx_, model_.learner_model_param->task)}; TreeUpdater::Create(name, ctx_, &model_.learner_model_param->task)};
up->LoadConfig(kv.second); up->LoadConfig(kv.second);
updaters_.push_back(std::move(up)); updaters_.push_back(std::move(up));
} }

View File

@ -1,20 +1,20 @@
/*! /**
* Copyright 2015-2022 by XGBoost Contributors * Copyright 2015-2023 by XGBoost Contributors
* \file tree_updater.cc * \file tree_updater.cc
* \brief Registry of tree updaters. * \brief Registry of tree updaters.
*/ */
#include "xgboost/tree_updater.h"
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include "xgboost/tree_updater.h" #include <string> // for string
#include "xgboost/host_device_vector.h"
namespace dmlc { namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg); DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
} // namespace dmlc } // namespace dmlc
namespace xgboost { namespace xgboost {
TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, ObjInfo const* task) {
TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, ObjInfo task) {
auto* e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name); auto* e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
if (e == nullptr) { if (e == nullptr) {
LOG(FATAL) << "Unknown tree updater " << name; LOG(FATAL) << "Unknown tree updater " << name;
@ -22,11 +22,9 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, Ob
auto p_updater = (e->body)(ctx, task); auto p_updater = (e->body)(ctx, task);
return p_updater; return p_updater;
} }
} // namespace xgboost } // namespace xgboost
namespace xgboost { namespace xgboost::tree {
namespace tree {
// List of files that will be force linked in static links. // List of files that will be force linked in static links.
DMLC_REGISTRY_LINK_TAG(updater_colmaker); DMLC_REGISTRY_LINK_TAG(updater_colmaker);
DMLC_REGISTRY_LINK_TAG(updater_refresh); DMLC_REGISTRY_LINK_TAG(updater_refresh);
@ -37,5 +35,4 @@ DMLC_REGISTRY_LINK_TAG(updater_sync);
#ifdef XGBOOST_USE_CUDA #ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(updater_gpu_hist); DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);
#endif // XGBOOST_USE_CUDA #endif // XGBOOST_USE_CUDA
} // namespace tree } // namespace xgboost::tree
} // namespace xgboost

View File

@ -14,14 +14,15 @@
#include "driver.h" #include "driver.h"
#include "hist/evaluate_splits.h" #include "hist/evaluate_splits.h"
#include "hist/histogram.h" #include "hist/histogram.h"
#include "hist/sampler.h" // SampleGradient #include "hist/sampler.h" // for SampleGradient
#include "param.h" #include "param.h"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/json.h" #include "xgboost/json.h"
#include "xgboost/linalg.h" #include "xgboost/linalg.h"
#include "xgboost/task.h" // for ObjInfo
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h"
#include "xgboost/tree_updater.h" #include "xgboost/tree_updater.h" // for TreeUpdater
namespace xgboost::tree { namespace xgboost::tree {
@ -40,12 +41,12 @@ auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
class GloablApproxBuilder { class GloablApproxBuilder {
protected: protected:
TrainParam const* param_; TrainParam const *param_;
std::shared_ptr<common::ColumnSampler> col_sampler_; std::shared_ptr<common::ColumnSampler> col_sampler_;
HistEvaluator<CPUExpandEntry> evaluator_; HistEvaluator<CPUExpandEntry> evaluator_;
HistogramBuilder<CPUExpandEntry> histogram_builder_; HistogramBuilder<CPUExpandEntry> histogram_builder_;
Context const *ctx_; Context const *ctx_;
ObjInfo const task_; ObjInfo const *const task_;
std::vector<CommonRowPartitioner> partitioner_; std::vector<CommonRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache. // Pointer to last updated tree, used for update prediction cache.
@ -63,7 +64,8 @@ class GloablApproxBuilder {
bst_bin_t n_total_bins = 0; bst_bin_t n_total_bins = 0;
partitioner_.clear(); partitioner_.clear();
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up? // Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, task_))) { for (auto const &page :
p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, *task_))) {
if (n_total_bins == 0) { if (n_total_bins == 0) {
n_total_bins = page.cut.TotalBins(); n_total_bins = page.cut.TotalBins();
feature_values_ = page.cut; feature_values_ = page.cut;
@ -157,7 +159,7 @@ class GloablApproxBuilder {
void LeafPartition(RegTree const &tree, common::Span<float const> hess, void LeafPartition(RegTree const &tree, common::Span<float const> hess,
std::vector<bst_node_t> *p_out_position) { std::vector<bst_node_t> *p_out_position) {
monitor_->Start(__func__); monitor_->Start(__func__);
if (!task_.UpdateTreeLeaf()) { if (!task_->UpdateTreeLeaf()) {
return; return;
} }
for (auto const &part : partitioner_) { for (auto const &part : partitioner_) {
@ -168,8 +170,8 @@ class GloablApproxBuilder {
public: public:
explicit GloablApproxBuilder(TrainParam const *param, MetaInfo const &info, Context const *ctx, explicit GloablApproxBuilder(TrainParam const *param, MetaInfo const &info, Context const *ctx,
std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task, std::shared_ptr<common::ColumnSampler> column_sampler,
common::Monitor *monitor) ObjInfo const *task, common::Monitor *monitor)
: param_{param}, : param_{param},
col_sampler_{std::move(column_sampler)}, col_sampler_{std::move(column_sampler)},
evaluator_{ctx, param_, info, col_sampler_}, evaluator_{ctx, param_, info, col_sampler_},
@ -256,10 +258,11 @@ class GlobalApproxUpdater : public TreeUpdater {
DMatrix *cached_{nullptr}; DMatrix *cached_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_ = std::shared_ptr<common::ColumnSampler> column_sampler_ =
std::make_shared<common::ColumnSampler>(); std::make_shared<common::ColumnSampler>();
ObjInfo task_; ObjInfo const *task_;
public: public:
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} { explicit GlobalApproxUpdater(Context const *ctx, ObjInfo const *task)
: TreeUpdater(ctx), task_{task} {
monitor_.Init(__func__); monitor_.Init(__func__);
} }
@ -317,5 +320,7 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
.describe( .describe(
"Tree constructor that uses approximate histogram construction " "Tree constructor that uses approximate histogram construction "
"for each node.") "for each node.")
.set_body([](Context const *ctx, ObjInfo task) { return new GlobalApproxUpdater(ctx, task); }); .set_body([](Context const *ctx, ObjInfo const *task) {
return new GlobalApproxUpdater(ctx, task);
});
} // namespace xgboost::tree } // namespace xgboost::tree

View File

@ -603,5 +603,5 @@ class ColMaker: public TreeUpdater {
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker") XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
.describe("Grow tree with parallelization over columns.") .describe("Grow tree with parallelization over columns.")
.set_body([](Context const *ctx, ObjInfo) { return new ColMaker(ctx); }); .set_body([](Context const *ctx, auto) { return new ColMaker(ctx); });
} // namespace xgboost::tree } // namespace xgboost::tree

View File

@ -15,12 +15,12 @@
#include "../collective/device_communicator.cuh" #include "../collective/device_communicator.cuh"
#include "../common/bitfield.h" #include "../common/bitfield.h"
#include "../common/categorical.h" #include "../common/categorical.h"
#include "../common/cuda_context.cuh" // CUDAContext
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#include "../common/hist_util.h" #include "../common/hist_util.h"
#include "../common/io.h" #include "../common/io.h"
#include "../common/timer.h" #include "../common/timer.h"
#include "../data/ellpack_page.cuh" #include "../data/ellpack_page.cuh"
#include "../common/cuda_context.cuh" // CUDAContext
#include "constraints.cuh" #include "constraints.cuh"
#include "driver.h" #include "driver.h"
#include "gpu_hist/evaluate_splits.cuh" #include "gpu_hist/evaluate_splits.cuh"
@ -39,11 +39,10 @@
#include "xgboost/json.h" #include "xgboost/json.h"
#include "xgboost/parameter.h" #include "xgboost/parameter.h"
#include "xgboost/span.h" #include "xgboost/span.h"
#include "xgboost/task.h" #include "xgboost/task.h" // for ObjInfo
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h"
namespace xgboost { namespace xgboost::tree {
namespace tree {
#if !defined(GTEST_TEST) #if !defined(GTEST_TEST)
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
#endif // !defined(GTEST_TEST) #endif // !defined(GTEST_TEST)
@ -106,12 +105,12 @@ class DeviceHistogramStorage {
nidx_map_.clear(); nidx_map_.clear();
overflow_nidx_map_.clear(); overflow_nidx_map_.clear();
} }
bool HistogramExists(int nidx) const { [[nodiscard]] bool HistogramExists(int nidx) const {
return nidx_map_.find(nidx) != nidx_map_.cend() || return nidx_map_.find(nidx) != nidx_map_.cend() ||
overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend(); overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
} }
int Bins() const { return n_bins_; } [[nodiscard]] int Bins() const { return n_bins_; }
size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; } [[nodiscard]] size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; } dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; }
void AllocateHistograms(const std::vector<int>& new_nidxs) { void AllocateHistograms(const std::vector<int>& new_nidxs) {
@ -690,8 +689,9 @@ struct GPUHistMakerDevice {
return root_entry; return root_entry;
} }
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task, void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
RegTree* p_tree, collective::DeviceCommunicator* communicator, ObjInfo const* task, RegTree* p_tree,
collective::DeviceCommunicator* communicator,
HostDeviceVector<bst_node_t>* p_out_position) { HostDeviceVector<bst_node_t>* p_out_position) {
auto& tree = *p_tree; auto& tree = *p_tree;
// Process maximum 32 nodes at a time // Process maximum 32 nodes at a time
@ -741,7 +741,7 @@ struct GPUHistMakerDevice {
} }
monitor.Start("FinalisePosition"); monitor.Start("FinalisePosition");
this->FinalisePosition(p_tree, p_fmat, task, p_out_position); this->FinalisePosition(p_tree, p_fmat, *task, p_out_position);
monitor.Stop("FinalisePosition"); monitor.Stop("FinalisePosition");
} }
}; };
@ -750,7 +750,7 @@ class GPUHistMaker : public TreeUpdater {
using GradientSumT = GradientPairPrecise; using GradientSumT = GradientPairPrecise;
public: public:
explicit GPUHistMaker(Context const* ctx, ObjInfo task) explicit GPUHistMaker(Context const* ctx, ObjInfo const* task)
: TreeUpdater(ctx), task_{task} {}; : TreeUpdater(ctx), task_{task} {};
void Configure(const Args& args) override { void Configure(const Args& args) override {
// Used in test to count how many configurations are performed // Used in test to count how many configurations are performed
@ -872,8 +872,8 @@ class GPUHistMaker : public TreeUpdater {
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
char const* Name() const override { return "grow_gpu_hist"; } [[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; }
bool HasNodePosition() const override { return true; } [[nodiscard]] bool HasNodePosition() const override { return true; }
private: private:
bool initialised_{false}; bool initialised_{false};
@ -882,7 +882,7 @@ class GPUHistMaker : public TreeUpdater {
DMatrix* p_last_fmat_{nullptr}; DMatrix* p_last_fmat_{nullptr};
RegTree const* p_last_tree_{nullptr}; RegTree const* p_last_tree_{nullptr};
ObjInfo task_; ObjInfo const* task_{nullptr};
common::Monitor monitor_; common::Monitor monitor_;
}; };
@ -890,8 +890,8 @@ class GPUHistMaker : public TreeUpdater {
#if !defined(GTEST_TEST) #if !defined(GTEST_TEST)
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.") .describe("Grow tree with GPU.")
.set_body([](Context const* ctx, ObjInfo task) { return new GPUHistMaker(ctx, task); }); .set_body([](Context const* ctx, ObjInfo const* task) {
return new GPUHistMaker(ctx, task);
});
#endif // !defined(GTEST_TEST) #endif // !defined(GTEST_TEST)
} // namespace xgboost::tree
} // namespace tree
} // namespace xgboost

View File

@ -18,7 +18,7 @@ DMLC_REGISTRY_FILE_TAG(updater_prune);
/*! \brief pruner that prunes a tree after growing finishes */ /*! \brief pruner that prunes a tree after growing finishes */
class TreePruner : public TreeUpdater { class TreePruner : public TreeUpdater {
public: public:
explicit TreePruner(Context const* ctx, ObjInfo task) : TreeUpdater(ctx) { explicit TreePruner(Context const* ctx, ObjInfo const* task) : TreeUpdater(ctx) {
syncher_.reset(TreeUpdater::Create("sync", ctx_, task)); syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
pruner_monitor_.Init("TreePruner"); pruner_monitor_.Init("TreePruner");
} }
@ -90,5 +90,7 @@ class TreePruner : public TreeUpdater {
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune") XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
.describe("Pruner that prune the tree according to statistics.") .describe("Pruner that prune the tree according to statistics.")
.set_body([](Context const* ctx, ObjInfo task) { return new TreePruner(ctx, task); }); .set_body([](Context const* ctx, ObjInfo const* task) {
return new TreePruner{ctx, task};
});
} // namespace xgboost::tree } // namespace xgboost::tree

View File

@ -35,7 +35,7 @@ void QuantileHistMaker::Update(TrainParam const *param, HostDeviceVector<Gradien
// build tree // build tree
const size_t n_trees = trees.size(); const size_t n_trees = trees.size();
if (!pimpl_) { if (!pimpl_) {
pimpl_.reset(new Builder(n_trees, param, dmat, task_, ctx_)); pimpl_.reset(new Builder(n_trees, param, dmat, *task_, ctx_));
} }
size_t t_idx{0}; size_t t_idx{0};
@ -287,6 +287,8 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker") XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
.describe("Grow tree using quantized histogram.") .describe("Grow tree using quantized histogram.")
.set_body([](Context const *ctx, ObjInfo task) { return new QuantileHistMaker(ctx, task); }); .set_body([](Context const *ctx, ObjInfo const *task) {
return new QuantileHistMaker(ctx, task);
});
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -43,7 +43,8 @@ inline BatchParam HistBatch(TrainParam const* param) {
/*! \brief construct a tree using quantized feature values */ /*! \brief construct a tree using quantized feature values */
class QuantileHistMaker: public TreeUpdater { class QuantileHistMaker: public TreeUpdater {
public: public:
explicit QuantileHistMaker(Context const* ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} {} explicit QuantileHistMaker(Context const* ctx, ObjInfo const* task)
: TreeUpdater(ctx), task_{task} {}
void Configure(const Args&) override {} void Configure(const Args&) override {}
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat, void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
@ -125,7 +126,7 @@ class QuantileHistMaker: public TreeUpdater {
protected: protected:
std::unique_ptr<Builder> pimpl_; std::unique_ptr<Builder> pimpl_;
ObjInfo task_; ObjInfo const* task_;
}; };
} // namespace xgboost::tree } // namespace xgboost::tree

View File

@ -142,5 +142,5 @@ class TreeRefresher : public TreeUpdater {
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh") XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
.describe("Refresher that refreshes the weight and statistics according to data.") .describe("Refresher that refreshes the weight and statistics according to data.")
.set_body([](Context const *ctx, ObjInfo) { return new TreeRefresher(ctx); }); .set_body([](Context const *ctx, auto) { return new TreeRefresher(ctx); });
} // namespace xgboost::tree } // namespace xgboost::tree

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2014-2013 by XGBoost Contributors * Copyright 2014-2023 by XGBoost Contributors
* \file updater_sync.cc * \file updater_sync.cc
* \brief synchronize the tree in all distributed nodes * \brief synchronize the tree in all distributed nodes
*/ */
@ -53,5 +53,5 @@ class TreeSyncher : public TreeUpdater {
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync") XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
.describe("Syncher that synchronize the tree in all distributed nodes.") .describe("Syncher that synchronize the tree in all distributed nodes.")
.set_body([](Context const* ctx, ObjInfo) { return new TreeSyncher(ctx); }); .set_body([](Context const* ctx, auto) { return new TreeSyncher(ctx); });
} // namespace xgboost::tree } // namespace xgboost::tree

View File

@ -170,8 +170,8 @@ void TestHistogramIndexImpl() {
// Build 2 matrices and build a histogram maker with that // Build 2 matrices and build a histogram maker with that
Context ctx(CreateEmptyGenericParam(0)); Context ctx(CreateEmptyGenericParam(0));
tree::GPUHistMaker hist_maker{&ctx, ObjInfo{ObjInfo::kRegression}}, ObjInfo task{ObjInfo::kRegression};
hist_maker_ext{&ctx, ObjInfo{ObjInfo::kRegression}}; tree::GPUHistMaker hist_maker{&ctx, &task}, hist_maker_ext{&ctx, &task};
std::unique_ptr<DMatrix> hist_maker_dmat( std::unique_ptr<DMatrix> hist_maker_dmat(
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true)); CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
@ -240,7 +240,8 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
param.UpdateAllowUnknown(args); param.UpdateAllowUnknown(args);
Context ctx(CreateEmptyGenericParam(0)); Context ctx(CreateEmptyGenericParam(0));
tree::GPUHistMaker hist_maker{&ctx,ObjInfo{ObjInfo::kRegression}}; ObjInfo task{ObjInfo::kRegression};
tree::GPUHistMaker hist_maker{&ctx, &task};
std::vector<HostDeviceVector<bst_node_t>> position(1); std::vector<HostDeviceVector<bst_node_t>> position(1);
hist_maker.Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position}, hist_maker.Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
@ -385,8 +386,8 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
TEST(GpuHist, ConfigIO) { TEST(GpuHist, ConfigIO) {
Context ctx(CreateEmptyGenericParam(0)); Context ctx(CreateEmptyGenericParam(0));
std::unique_ptr<TreeUpdater> updater{ ObjInfo task{ObjInfo::kRegression};
TreeUpdater::Create("grow_gpu_hist", &ctx, ObjInfo{ObjInfo::kRegression})}; std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_hist", &ctx, &task)};
updater->Configure(Args{}); updater->Configure(Args{});
Json j_updater { Object() }; Json j_updater { Object() };

View File

@ -37,13 +37,13 @@ TEST(GrowHistMaker, InteractionConstraint)
auto p_gradients = GenerateGradients(kRows); auto p_gradients = GenerateGradients(kRows);
Context ctx; Context ctx;
ObjInfo task{ObjInfo::kRegression};
{ {
// With constraints // With constraints
RegTree tree; RegTree tree;
tree.param.num_feature = kCols; tree.param.num_feature = kCols;
std::unique_ptr<TreeUpdater> updater{ std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
TrainParam param; TrainParam param;
param.UpdateAllowUnknown( param.UpdateAllowUnknown(
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}}); Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
@ -61,8 +61,7 @@ TEST(GrowHistMaker, InteractionConstraint)
RegTree tree; RegTree tree;
tree.param.num_feature = kCols; tree.param.num_feature = kCols;
std::unique_ptr<TreeUpdater> updater{ std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
std::vector<HostDeviceVector<bst_node_t>> position(1); std::vector<HostDeviceVector<bst_node_t>> position(1);
TrainParam param; TrainParam param;
param.Init(Args{}); param.Init(Args{});
@ -81,8 +80,8 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
auto p_dmat = GenerateDMatrix(rows, cols); auto p_dmat = GenerateDMatrix(rows, cols);
auto p_gradients = GenerateGradients(rows); auto p_gradients = GenerateGradients(rows);
Context ctx; Context ctx;
std::unique_ptr<TreeUpdater> updater{ ObjInfo task{ObjInfo::kRegression};
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})}; std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
std::vector<HostDeviceVector<bst_node_t>> position(1); std::vector<HostDeviceVector<bst_node_t>> position(1);
std::unique_ptr<DMatrix> sliced{ std::unique_ptr<DMatrix> sliced{
@ -110,12 +109,12 @@ TEST(GrowHistMaker, ColumnSplit) {
RegTree expected_tree; RegTree expected_tree;
expected_tree.param.num_feature = kCols; expected_tree.param.num_feature = kCols;
ObjInfo task{ObjInfo::kRegression};
{ {
auto p_dmat = GenerateDMatrix(kRows, kCols); auto p_dmat = GenerateDMatrix(kRows, kCols);
auto p_gradients = GenerateGradients(kRows); auto p_gradients = GenerateGradients(kRows);
Context ctx; Context ctx;
std::unique_ptr<TreeUpdater> updater{ std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
std::vector<HostDeviceVector<bst_node_t>> position(1); std::vector<HostDeviceVector<bst_node_t>> position(1);
TrainParam param; TrainParam param;
param.Init(Args{}); param.Init(Args{});

View File

@ -2,22 +2,25 @@
* Copyright 2023 by XGBoost contributors * Copyright 2023 by XGBoost contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/task.h> #include <xgboost/context.h> // for Context
#include <xgboost/tree_updater.h> #include <xgboost/task.h> // for ObjInfo
#include <xgboost/tree_updater.h> // for TreeUpdater
#include <memory> // for unique_ptr
namespace xgboost { namespace xgboost {
TEST(Updater, HasNodePosition) { TEST(Updater, HasNodePosition) {
Context ctx; Context ctx;
ObjInfo task{ObjInfo::kRegression, true, true}; ObjInfo task{ObjInfo::kRegression, true, true};
std::unique_ptr<TreeUpdater> up{TreeUpdater::Create("grow_histmaker", &ctx, task)}; std::unique_ptr<TreeUpdater> up{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
ASSERT_TRUE(up->HasNodePosition()); ASSERT_TRUE(up->HasNodePosition());
up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, task)); up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task));
ASSERT_TRUE(up->HasNodePosition()); ASSERT_TRUE(up->HasNodePosition());
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
ctx.gpu_id = 0; ctx.gpu_id = 0;
up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, task)); up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, &task));
ASSERT_TRUE(up->HasNodePosition()); ASSERT_TRUE(up->HasNodePosition());
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
} }

View File

@ -9,6 +9,7 @@
#include "../../../src/tree/param.h" // for TrainParam #include "../../../src/tree/param.h" // for TrainParam
#include "../helpers.h" #include "../helpers.h"
#include "xgboost/task.h" // for ObjInfo
namespace xgboost { namespace xgboost {
@ -71,8 +72,8 @@ class TestPredictionCache : public ::testing::Test {
ctx.gpu_id = Context::kCpuId; ctx.gpu_id = Context::kCpuId;
} }
std::unique_ptr<TreeUpdater> updater{ ObjInfo task{ObjInfo::kRegression};
TreeUpdater::Create(updater_name, &ctx, ObjInfo{ObjInfo::kRegression})}; std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, &ctx, &task)};
RegTree tree; RegTree tree;
std::vector<RegTree *> trees{&tree}; std::vector<RegTree *> trees{&tree};
auto gpair = GenerateRandomGradients(n_samples_); auto gpair = GenerateRandomGradients(n_samples_);

View File

@ -39,8 +39,8 @@ TEST(Updater, Prune) {
TrainParam param; TrainParam param;
param.UpdateAllowUnknown(cfg); param.UpdateAllowUnknown(cfg);
std::unique_ptr<TreeUpdater> pruner( ObjInfo task{ObjInfo::kRegression};
TreeUpdater::Create("prune", &ctx, ObjInfo{ObjInfo::kRegression})); std::unique_ptr<TreeUpdater> pruner(TreeUpdater::Create("prune", &ctx, &task));
// loss_chg < min_split_loss; // loss_chg < min_split_loss;
std::vector<HostDeviceVector<bst_node_t>> position(trees.size()); std::vector<HostDeviceVector<bst_node_t>> position(trees.size());

View File

@ -1,8 +1,9 @@
/** /**
* Copyright 2018-2013 by XGBoost Contributors * Copyright 2018-2023 by XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <xgboost/task.h> // for ObjInfo
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
#include <memory> #include <memory>
@ -12,9 +13,7 @@
#include "../../../src/tree/param.h" // for TrainParam #include "../../../src/tree/param.h" // for TrainParam
#include "../helpers.h" #include "../helpers.h"
namespace xgboost { namespace xgboost::tree {
namespace tree {
TEST(Updater, Refresh) { TEST(Updater, Refresh) {
bst_row_t constexpr kRows = 8; bst_row_t constexpr kRows = 8;
bst_feature_t constexpr kCols = 16; bst_feature_t constexpr kCols = 16;
@ -33,8 +32,9 @@ TEST(Updater, Refresh) {
auto ctx = CreateEmptyGenericParam(GPUIDX); auto ctx = CreateEmptyGenericParam(GPUIDX);
tree.param.UpdateAllowUnknown(cfg); tree.param.UpdateAllowUnknown(cfg);
std::vector<RegTree*> trees{&tree}; std::vector<RegTree*> trees{&tree};
std::unique_ptr<TreeUpdater> refresher(
TreeUpdater::Create("refresh", &ctx, ObjInfo{ObjInfo::kRegression})); ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &ctx, &task));
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f, tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f); /*left_sum=*/0.0f, /*right_sum=*/0.0f);
@ -57,6 +57,4 @@ TEST(Updater, Refresh) {
ASSERT_NEAR(0, tree.Stat(1).loss_chg, kEps); ASSERT_NEAR(0, tree.Stat(1).loss_chg, kEps);
ASSERT_NEAR(0, tree.Stat(2).loss_chg, kEps); ASSERT_NEAR(0, tree.Stat(2).loss_chg, kEps);
} }
} // namespace xgboost::tree
} // namespace tree
} // namespace xgboost

View File

@ -2,9 +2,13 @@
* Copyright 2020-2023 by XGBoost Contributors * Copyright 2020-2023 by XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/context.h> // for Context
#include <xgboost/task.h> // for ObjInfo
#include <xgboost/tree_model.h> #include <xgboost/tree_model.h>
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
#include <memory> // for unique_ptr
#include "../../../src/tree/param.h" // for TrainParam #include "../../../src/tree/param.h" // for TrainParam
#include "../helpers.h" #include "../helpers.h"
@ -26,12 +30,12 @@ class UpdaterTreeStatTest : public ::testing::Test {
void RunTest(std::string updater) { void RunTest(std::string updater) {
tree::TrainParam param; tree::TrainParam param;
ObjInfo task{ObjInfo::kRegression};
param.Init(Args{}); param.Init(Args{});
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0) Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
: CreateEmptyGenericParam(Context::kCpuId)); : CreateEmptyGenericParam(Context::kCpuId));
auto up = std::unique_ptr<TreeUpdater>{ auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
up->Configure(Args{}); up->Configure(Args{});
RegTree tree; RegTree tree;
tree.param.num_feature = kCols; tree.param.num_feature = kCols;
@ -74,18 +78,18 @@ class UpdaterEtaTest : public ::testing::Test {
} }
void RunTest(std::string updater) { void RunTest(std::string updater) {
ObjInfo task{ObjInfo::kClassification};
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0) Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
: CreateEmptyGenericParam(Context::kCpuId)); : CreateEmptyGenericParam(Context::kCpuId));
float eta = 0.4; float eta = 0.4;
auto up_0 = std::unique_ptr<TreeUpdater>{ auto up_0 = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
up_0->Configure(Args{}); up_0->Configure(Args{});
tree::TrainParam param0; tree::TrainParam param0;
param0.Init(Args{{"eta", std::to_string(eta)}}); param0.Init(Args{{"eta", std::to_string(eta)}});
auto up_1 = std::unique_ptr<TreeUpdater>{ auto up_1 = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
up_1->Configure(Args{{"eta", "1.0"}}); up_1->Configure(Args{{"eta", "1.0"}});
tree::TrainParam param1; tree::TrainParam param1;
param1.Init(Args{{"eta", "1.0"}}); param1.Init(Args{{"eta", "1.0"}});
@ -153,11 +157,11 @@ class TestMinSplitLoss : public ::testing::Test {
{"gamma", std::to_string(gamma)}}; {"gamma", std::to_string(gamma)}};
tree::TrainParam param; tree::TrainParam param;
param.UpdateAllowUnknown(args); param.UpdateAllowUnknown(args);
ObjInfo task{ObjInfo::kRegression};
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0) Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
: CreateEmptyGenericParam(Context::kCpuId)); : CreateEmptyGenericParam(Context::kCpuId));
auto up = std::unique_ptr<TreeUpdater>{ auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
up->Configure({}); up->Configure({});
RegTree tree; RegTree tree;