Pass obj info by reference instead of by value. (#8889)
- Pass obj info into tree updater as const pointer. This way we don't have to initialize the learner model param before configuring gbm, hence breaking up the dependency of configurations.
This commit is contained in:
@@ -340,7 +340,7 @@ void GBTree::InitUpdater(Args const& cfg) {
|
||||
// create new updaters
|
||||
for (const std::string& pstr : ups) {
|
||||
std::unique_ptr<TreeUpdater> up(
|
||||
TreeUpdater::Create(pstr.c_str(), ctx_, model_.learner_model_param->task));
|
||||
TreeUpdater::Create(pstr.c_str(), ctx_, &model_.learner_model_param->task));
|
||||
up->Configure(cfg);
|
||||
updaters_.push_back(std::move(up));
|
||||
}
|
||||
@@ -448,7 +448,7 @@ void GBTree::LoadConfig(Json const& in) {
|
||||
LOG(WARNING) << "Changing updater from `grow_gpu_hist` to `grow_quantile_histmaker`.";
|
||||
}
|
||||
std::unique_ptr<TreeUpdater> up{
|
||||
TreeUpdater::Create(name, ctx_, model_.learner_model_param->task)};
|
||||
TreeUpdater::Create(name, ctx_, &model_.learner_model_param->task)};
|
||||
up->LoadConfig(kv.second);
|
||||
updaters_.push_back(std::move(up));
|
||||
}
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
/*!
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file tree_updater.cc
|
||||
* \brief Registry of tree updaters.
|
||||
*/
|
||||
#include "xgboost/tree_updater.h"
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include <string> // for string
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, ObjInfo task) {
|
||||
TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, ObjInfo const* task) {
|
||||
auto* e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown tree updater " << name;
|
||||
@@ -22,11 +22,9 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, Ob
|
||||
auto p_updater = (e->body)(ctx, task);
|
||||
return p_updater;
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace xgboost::tree {
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(updater_colmaker);
|
||||
DMLC_REGISTRY_LINK_TAG(updater_refresh);
|
||||
@@ -37,5 +35,4 @@ DMLC_REGISTRY_LINK_TAG(updater_sync);
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@@ -14,14 +14,15 @@
|
||||
#include "driver.h"
|
||||
#include "hist/evaluate_splits.h"
|
||||
#include "hist/histogram.h"
|
||||
#include "hist/sampler.h" // SampleGradient
|
||||
#include "hist/sampler.h" // for SampleGradient
|
||||
#include "param.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "xgboost/tree_updater.h" // for TreeUpdater
|
||||
|
||||
namespace xgboost::tree {
|
||||
|
||||
@@ -40,12 +41,12 @@ auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
|
||||
|
||||
class GloablApproxBuilder {
|
||||
protected:
|
||||
TrainParam const* param_;
|
||||
TrainParam const *param_;
|
||||
std::shared_ptr<common::ColumnSampler> col_sampler_;
|
||||
HistEvaluator<CPUExpandEntry> evaluator_;
|
||||
HistogramBuilder<CPUExpandEntry> histogram_builder_;
|
||||
Context const *ctx_;
|
||||
ObjInfo const task_;
|
||||
ObjInfo const *const task_;
|
||||
|
||||
std::vector<CommonRowPartitioner> partitioner_;
|
||||
// Pointer to last updated tree, used for update prediction cache.
|
||||
@@ -63,7 +64,8 @@ class GloablApproxBuilder {
|
||||
bst_bin_t n_total_bins = 0;
|
||||
partitioner_.clear();
|
||||
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, task_))) {
|
||||
for (auto const &page :
|
||||
p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, *task_))) {
|
||||
if (n_total_bins == 0) {
|
||||
n_total_bins = page.cut.TotalBins();
|
||||
feature_values_ = page.cut;
|
||||
@@ -157,7 +159,7 @@ class GloablApproxBuilder {
|
||||
void LeafPartition(RegTree const &tree, common::Span<float const> hess,
|
||||
std::vector<bst_node_t> *p_out_position) {
|
||||
monitor_->Start(__func__);
|
||||
if (!task_.UpdateTreeLeaf()) {
|
||||
if (!task_->UpdateTreeLeaf()) {
|
||||
return;
|
||||
}
|
||||
for (auto const &part : partitioner_) {
|
||||
@@ -168,8 +170,8 @@ class GloablApproxBuilder {
|
||||
|
||||
public:
|
||||
explicit GloablApproxBuilder(TrainParam const *param, MetaInfo const &info, Context const *ctx,
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task,
|
||||
common::Monitor *monitor)
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler,
|
||||
ObjInfo const *task, common::Monitor *monitor)
|
||||
: param_{param},
|
||||
col_sampler_{std::move(column_sampler)},
|
||||
evaluator_{ctx, param_, info, col_sampler_},
|
||||
@@ -256,10 +258,11 @@ class GlobalApproxUpdater : public TreeUpdater {
|
||||
DMatrix *cached_{nullptr};
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_ =
|
||||
std::make_shared<common::ColumnSampler>();
|
||||
ObjInfo task_;
|
||||
ObjInfo const *task_;
|
||||
|
||||
public:
|
||||
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} {
|
||||
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo const *task)
|
||||
: TreeUpdater(ctx), task_{task} {
|
||||
monitor_.Init(__func__);
|
||||
}
|
||||
|
||||
@@ -317,5 +320,7 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
|
||||
.describe(
|
||||
"Tree constructor that uses approximate histogram construction "
|
||||
"for each node.")
|
||||
.set_body([](Context const *ctx, ObjInfo task) { return new GlobalApproxUpdater(ctx, task); });
|
||||
.set_body([](Context const *ctx, ObjInfo const *task) {
|
||||
return new GlobalApproxUpdater(ctx, task);
|
||||
});
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@@ -603,5 +603,5 @@ class ColMaker: public TreeUpdater {
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
||||
.describe("Grow tree with parallelization over columns.")
|
||||
.set_body([](Context const *ctx, ObjInfo) { return new ColMaker(ctx); });
|
||||
.set_body([](Context const *ctx, auto) { return new ColMaker(ctx); });
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@@ -15,12 +15,12 @@
|
||||
#include "../collective/device_communicator.cuh"
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "constraints.cuh"
|
||||
#include "driver.h"
|
||||
#include "gpu_hist/evaluate_splits.cuh"
|
||||
@@ -39,11 +39,10 @@
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/task.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace xgboost::tree {
|
||||
#if !defined(GTEST_TEST)
|
||||
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
|
||||
#endif // !defined(GTEST_TEST)
|
||||
@@ -106,12 +105,12 @@ class DeviceHistogramStorage {
|
||||
nidx_map_.clear();
|
||||
overflow_nidx_map_.clear();
|
||||
}
|
||||
bool HistogramExists(int nidx) const {
|
||||
[[nodiscard]] bool HistogramExists(int nidx) const {
|
||||
return nidx_map_.find(nidx) != nidx_map_.cend() ||
|
||||
overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
|
||||
}
|
||||
int Bins() const { return n_bins_; }
|
||||
size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
|
||||
[[nodiscard]] int Bins() const { return n_bins_; }
|
||||
[[nodiscard]] size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
|
||||
dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; }
|
||||
|
||||
void AllocateHistograms(const std::vector<int>& new_nidxs) {
|
||||
@@ -690,8 +689,9 @@ struct GPUHistMakerDevice {
|
||||
return root_entry;
|
||||
}
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
|
||||
RegTree* p_tree, collective::DeviceCommunicator* communicator,
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
|
||||
ObjInfo const* task, RegTree* p_tree,
|
||||
collective::DeviceCommunicator* communicator,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
auto& tree = *p_tree;
|
||||
// Process maximum 32 nodes at a time
|
||||
@@ -741,7 +741,7 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
monitor.Start("FinalisePosition");
|
||||
this->FinalisePosition(p_tree, p_fmat, task, p_out_position);
|
||||
this->FinalisePosition(p_tree, p_fmat, *task, p_out_position);
|
||||
monitor.Stop("FinalisePosition");
|
||||
}
|
||||
};
|
||||
@@ -750,7 +750,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
using GradientSumT = GradientPairPrecise;
|
||||
|
||||
public:
|
||||
explicit GPUHistMaker(Context const* ctx, ObjInfo task)
|
||||
explicit GPUHistMaker(Context const* ctx, ObjInfo const* task)
|
||||
: TreeUpdater(ctx), task_{task} {};
|
||||
void Configure(const Args& args) override {
|
||||
// Used in test to count how many configurations are performed
|
||||
@@ -872,8 +872,8 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
||||
|
||||
char const* Name() const override { return "grow_gpu_hist"; }
|
||||
bool HasNodePosition() const override { return true; }
|
||||
[[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; }
|
||||
[[nodiscard]] bool HasNodePosition() const override { return true; }
|
||||
|
||||
private:
|
||||
bool initialised_{false};
|
||||
@@ -882,7 +882,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
DMatrix* p_last_fmat_{nullptr};
|
||||
RegTree const* p_last_tree_{nullptr};
|
||||
ObjInfo task_;
|
||||
ObjInfo const* task_{nullptr};
|
||||
|
||||
common::Monitor monitor_;
|
||||
};
|
||||
@@ -890,8 +890,8 @@ class GPUHistMaker : public TreeUpdater {
|
||||
#if !defined(GTEST_TEST)
|
||||
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
||||
.describe("Grow tree with GPU.")
|
||||
.set_body([](Context const* ctx, ObjInfo task) { return new GPUHistMaker(ctx, task); });
|
||||
.set_body([](Context const* ctx, ObjInfo const* task) {
|
||||
return new GPUHistMaker(ctx, task);
|
||||
});
|
||||
#endif // !defined(GTEST_TEST)
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@@ -18,7 +18,7 @@ DMLC_REGISTRY_FILE_TAG(updater_prune);
|
||||
/*! \brief pruner that prunes a tree after growing finishes */
|
||||
class TreePruner : public TreeUpdater {
|
||||
public:
|
||||
explicit TreePruner(Context const* ctx, ObjInfo task) : TreeUpdater(ctx) {
|
||||
explicit TreePruner(Context const* ctx, ObjInfo const* task) : TreeUpdater(ctx) {
|
||||
syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
|
||||
pruner_monitor_.Init("TreePruner");
|
||||
}
|
||||
@@ -90,5 +90,7 @@ class TreePruner : public TreeUpdater {
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
||||
.describe("Pruner that prune the tree according to statistics.")
|
||||
.set_body([](Context const* ctx, ObjInfo task) { return new TreePruner(ctx, task); });
|
||||
.set_body([](Context const* ctx, ObjInfo const* task) {
|
||||
return new TreePruner{ctx, task};
|
||||
});
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@@ -35,7 +35,7 @@ void QuantileHistMaker::Update(TrainParam const *param, HostDeviceVector<Gradien
|
||||
// build tree
|
||||
const size_t n_trees = trees.size();
|
||||
if (!pimpl_) {
|
||||
pimpl_.reset(new Builder(n_trees, param, dmat, task_, ctx_));
|
||||
pimpl_.reset(new Builder(n_trees, param, dmat, *task_, ctx_));
|
||||
}
|
||||
|
||||
size_t t_idx{0};
|
||||
@@ -287,6 +287,8 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
|
||||
.describe("Grow tree using quantized histogram.")
|
||||
.set_body([](Context const *ctx, ObjInfo task) { return new QuantileHistMaker(ctx, task); });
|
||||
.set_body([](Context const *ctx, ObjInfo const *task) {
|
||||
return new QuantileHistMaker(ctx, task);
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -43,7 +43,8 @@ inline BatchParam HistBatch(TrainParam const* param) {
|
||||
/*! \brief construct a tree using quantized feature values */
|
||||
class QuantileHistMaker: public TreeUpdater {
|
||||
public:
|
||||
explicit QuantileHistMaker(Context const* ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} {}
|
||||
explicit QuantileHistMaker(Context const* ctx, ObjInfo const* task)
|
||||
: TreeUpdater(ctx), task_{task} {}
|
||||
void Configure(const Args&) override {}
|
||||
|
||||
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
@@ -125,7 +126,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
|
||||
protected:
|
||||
std::unique_ptr<Builder> pimpl_;
|
||||
ObjInfo task_;
|
||||
ObjInfo const* task_;
|
||||
};
|
||||
} // namespace xgboost::tree
|
||||
|
||||
|
||||
@@ -142,5 +142,5 @@ class TreeRefresher : public TreeUpdater {
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
|
||||
.describe("Refresher that refreshes the weight and statistics according to data.")
|
||||
.set_body([](Context const *ctx, ObjInfo) { return new TreeRefresher(ctx); });
|
||||
.set_body([](Context const *ctx, auto) { return new TreeRefresher(ctx); });
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2013 by XBGoost Contributors
|
||||
* Copyright 2014-2023 by XBGoost Contributors
|
||||
* \file updater_sync.cc
|
||||
* \brief synchronize the tree in all distributed nodes
|
||||
*/
|
||||
@@ -53,5 +53,5 @@ class TreeSyncher : public TreeUpdater {
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
|
||||
.describe("Syncher that synchronize the tree in all distributed nodes.")
|
||||
.set_body([](Context const* ctx, ObjInfo) { return new TreeSyncher(ctx); });
|
||||
.set_body([](Context const* ctx, auto) { return new TreeSyncher(ctx); });
|
||||
} // namespace xgboost::tree
|
||||
|
||||
Reference in New Issue
Block a user