Pass obj info by reference instead of by value. (#8889)
- Pass obj info into tree updater as const pointer. This way we don't have to initialize the learner model param before configuring gbm, hence breaking up the dependency of configurations.
This commit is contained in:
parent
54e001bbf4
commit
6deaec8027
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2014-2023 by XGBoost Contributors
|
||||
* \file tree_updater.h
|
||||
* \brief General primitive for tree learning,
|
||||
* Updating a collection of trees given the information.
|
||||
@ -9,19 +9,17 @@
|
||||
#define XGBOOST_TREE_UPDATER_H_
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/context.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/linalg.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/task.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/base.h> // for Args, GradientPair
|
||||
#include <xgboost/data.h> // DMatrix
|
||||
#include <xgboost/host_device_vector.h> // for HostDeviceVector
|
||||
#include <xgboost/linalg.h> // for VectorView
|
||||
#include <xgboost/model.h> // for Configurable
|
||||
#include <xgboost/span.h> // for Span
|
||||
#include <xgboost/tree_model.h> // for RegTree
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <functional> // for function
|
||||
#include <string> // for string
|
||||
#include <vector> // for vector
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -30,8 +28,9 @@ struct TrainParam;
|
||||
|
||||
class Json;
|
||||
struct Context;
|
||||
struct ObjInfo;
|
||||
|
||||
/*!
|
||||
/**
|
||||
* \brief interface of tree update module, that performs update of a tree.
|
||||
*/
|
||||
class TreeUpdater : public Configurable {
|
||||
@ -53,12 +52,12 @@ class TreeUpdater : public Configurable {
|
||||
* used for modifying existing trees (like `prune`). Return true if it can modify
|
||||
* existing trees.
|
||||
*/
|
||||
virtual bool CanModifyTree() const { return false; }
|
||||
[[nodiscard]] virtual bool CanModifyTree() const { return false; }
|
||||
/*!
|
||||
* \brief Wether the out_position in `Update` is valid. This determines whether adaptive
|
||||
* tree can be used.
|
||||
*/
|
||||
virtual bool HasNodePosition() const { return false; }
|
||||
[[nodiscard]] virtual bool HasNodePosition() const { return false; }
|
||||
/**
|
||||
* \brief perform update to the tree models
|
||||
*
|
||||
@ -91,14 +90,15 @@ class TreeUpdater : public Configurable {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual char const* Name() const = 0;
|
||||
[[nodiscard]] virtual char const* Name() const = 0;
|
||||
|
||||
/*!
|
||||
/**
|
||||
* \brief Create a tree updater given name
|
||||
* \param name Name of the tree updater.
|
||||
* \param ctx A global runtime parameter
|
||||
* \param task Infomation about the objective.
|
||||
*/
|
||||
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo task);
|
||||
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task);
|
||||
};
|
||||
|
||||
/*!
|
||||
@ -106,7 +106,7 @@ class TreeUpdater : public Configurable {
|
||||
*/
|
||||
struct TreeUpdaterReg
|
||||
: public dmlc::FunctionRegEntryBase<
|
||||
TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo task)>> {};
|
||||
TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo const* task)>> {};
|
||||
|
||||
/*!
|
||||
* \brief Macro to register tree updater.
|
||||
|
||||
@ -340,7 +340,7 @@ void GBTree::InitUpdater(Args const& cfg) {
|
||||
// create new updaters
|
||||
for (const std::string& pstr : ups) {
|
||||
std::unique_ptr<TreeUpdater> up(
|
||||
TreeUpdater::Create(pstr.c_str(), ctx_, model_.learner_model_param->task));
|
||||
TreeUpdater::Create(pstr.c_str(), ctx_, &model_.learner_model_param->task));
|
||||
up->Configure(cfg);
|
||||
updaters_.push_back(std::move(up));
|
||||
}
|
||||
@ -448,7 +448,7 @@ void GBTree::LoadConfig(Json const& in) {
|
||||
LOG(WARNING) << "Changing updater from `grow_gpu_hist` to `grow_quantile_histmaker`.";
|
||||
}
|
||||
std::unique_ptr<TreeUpdater> up{
|
||||
TreeUpdater::Create(name, ctx_, model_.learner_model_param->task)};
|
||||
TreeUpdater::Create(name, ctx_, &model_.learner_model_param->task)};
|
||||
up->LoadConfig(kv.second);
|
||||
updaters_.push_back(std::move(up));
|
||||
}
|
||||
|
||||
@ -1,20 +1,20 @@
|
||||
/*!
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file tree_updater.cc
|
||||
* \brief Registry of tree updaters.
|
||||
*/
|
||||
#include "xgboost/tree_updater.h"
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include <string> // for string
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, ObjInfo task) {
|
||||
TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, ObjInfo const* task) {
|
||||
auto* e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown tree updater " << name;
|
||||
@ -22,11 +22,9 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, Ob
|
||||
auto p_updater = (e->body)(ctx, task);
|
||||
return p_updater;
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace xgboost::tree {
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(updater_colmaker);
|
||||
DMLC_REGISTRY_LINK_TAG(updater_refresh);
|
||||
@ -37,5 +35,4 @@ DMLC_REGISTRY_LINK_TAG(updater_sync);
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@ -14,14 +14,15 @@
|
||||
#include "driver.h"
|
||||
#include "hist/evaluate_splits.h"
|
||||
#include "hist/histogram.h"
|
||||
#include "hist/sampler.h" // SampleGradient
|
||||
#include "hist/sampler.h" // for SampleGradient
|
||||
#include "param.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "xgboost/tree_updater.h" // for TreeUpdater
|
||||
|
||||
namespace xgboost::tree {
|
||||
|
||||
@ -45,7 +46,7 @@ class GloablApproxBuilder {
|
||||
HistEvaluator<CPUExpandEntry> evaluator_;
|
||||
HistogramBuilder<CPUExpandEntry> histogram_builder_;
|
||||
Context const *ctx_;
|
||||
ObjInfo const task_;
|
||||
ObjInfo const *const task_;
|
||||
|
||||
std::vector<CommonRowPartitioner> partitioner_;
|
||||
// Pointer to last updated tree, used for update prediction cache.
|
||||
@ -63,7 +64,8 @@ class GloablApproxBuilder {
|
||||
bst_bin_t n_total_bins = 0;
|
||||
partitioner_.clear();
|
||||
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, task_))) {
|
||||
for (auto const &page :
|
||||
p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, *task_))) {
|
||||
if (n_total_bins == 0) {
|
||||
n_total_bins = page.cut.TotalBins();
|
||||
feature_values_ = page.cut;
|
||||
@ -157,7 +159,7 @@ class GloablApproxBuilder {
|
||||
void LeafPartition(RegTree const &tree, common::Span<float const> hess,
|
||||
std::vector<bst_node_t> *p_out_position) {
|
||||
monitor_->Start(__func__);
|
||||
if (!task_.UpdateTreeLeaf()) {
|
||||
if (!task_->UpdateTreeLeaf()) {
|
||||
return;
|
||||
}
|
||||
for (auto const &part : partitioner_) {
|
||||
@ -168,8 +170,8 @@ class GloablApproxBuilder {
|
||||
|
||||
public:
|
||||
explicit GloablApproxBuilder(TrainParam const *param, MetaInfo const &info, Context const *ctx,
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task,
|
||||
common::Monitor *monitor)
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler,
|
||||
ObjInfo const *task, common::Monitor *monitor)
|
||||
: param_{param},
|
||||
col_sampler_{std::move(column_sampler)},
|
||||
evaluator_{ctx, param_, info, col_sampler_},
|
||||
@ -256,10 +258,11 @@ class GlobalApproxUpdater : public TreeUpdater {
|
||||
DMatrix *cached_{nullptr};
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_ =
|
||||
std::make_shared<common::ColumnSampler>();
|
||||
ObjInfo task_;
|
||||
ObjInfo const *task_;
|
||||
|
||||
public:
|
||||
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} {
|
||||
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo const *task)
|
||||
: TreeUpdater(ctx), task_{task} {
|
||||
monitor_.Init(__func__);
|
||||
}
|
||||
|
||||
@ -317,5 +320,7 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
|
||||
.describe(
|
||||
"Tree constructor that uses approximate histogram construction "
|
||||
"for each node.")
|
||||
.set_body([](Context const *ctx, ObjInfo task) { return new GlobalApproxUpdater(ctx, task); });
|
||||
.set_body([](Context const *ctx, ObjInfo const *task) {
|
||||
return new GlobalApproxUpdater(ctx, task);
|
||||
});
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@ -603,5 +603,5 @@ class ColMaker: public TreeUpdater {
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
||||
.describe("Grow tree with parallelization over columns.")
|
||||
.set_body([](Context const *ctx, ObjInfo) { return new ColMaker(ctx); });
|
||||
.set_body([](Context const *ctx, auto) { return new ColMaker(ctx); });
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@ -15,12 +15,12 @@
|
||||
#include "../collective/device_communicator.cuh"
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "constraints.cuh"
|
||||
#include "driver.h"
|
||||
#include "gpu_hist/evaluate_splits.cuh"
|
||||
@ -39,11 +39,10 @@
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/task.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace xgboost::tree {
|
||||
#if !defined(GTEST_TEST)
|
||||
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
|
||||
#endif // !defined(GTEST_TEST)
|
||||
@ -106,12 +105,12 @@ class DeviceHistogramStorage {
|
||||
nidx_map_.clear();
|
||||
overflow_nidx_map_.clear();
|
||||
}
|
||||
bool HistogramExists(int nidx) const {
|
||||
[[nodiscard]] bool HistogramExists(int nidx) const {
|
||||
return nidx_map_.find(nidx) != nidx_map_.cend() ||
|
||||
overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
|
||||
}
|
||||
int Bins() const { return n_bins_; }
|
||||
size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
|
||||
[[nodiscard]] int Bins() const { return n_bins_; }
|
||||
[[nodiscard]] size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
|
||||
dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; }
|
||||
|
||||
void AllocateHistograms(const std::vector<int>& new_nidxs) {
|
||||
@ -690,8 +689,9 @@ struct GPUHistMakerDevice {
|
||||
return root_entry;
|
||||
}
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
|
||||
RegTree* p_tree, collective::DeviceCommunicator* communicator,
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
|
||||
ObjInfo const* task, RegTree* p_tree,
|
||||
collective::DeviceCommunicator* communicator,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
auto& tree = *p_tree;
|
||||
// Process maximum 32 nodes at a time
|
||||
@ -741,7 +741,7 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
monitor.Start("FinalisePosition");
|
||||
this->FinalisePosition(p_tree, p_fmat, task, p_out_position);
|
||||
this->FinalisePosition(p_tree, p_fmat, *task, p_out_position);
|
||||
monitor.Stop("FinalisePosition");
|
||||
}
|
||||
};
|
||||
@ -750,7 +750,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
using GradientSumT = GradientPairPrecise;
|
||||
|
||||
public:
|
||||
explicit GPUHistMaker(Context const* ctx, ObjInfo task)
|
||||
explicit GPUHistMaker(Context const* ctx, ObjInfo const* task)
|
||||
: TreeUpdater(ctx), task_{task} {};
|
||||
void Configure(const Args& args) override {
|
||||
// Used in test to count how many configurations are performed
|
||||
@ -872,8 +872,8 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
||||
|
||||
char const* Name() const override { return "grow_gpu_hist"; }
|
||||
bool HasNodePosition() const override { return true; }
|
||||
[[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; }
|
||||
[[nodiscard]] bool HasNodePosition() const override { return true; }
|
||||
|
||||
private:
|
||||
bool initialised_{false};
|
||||
@ -882,7 +882,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
DMatrix* p_last_fmat_{nullptr};
|
||||
RegTree const* p_last_tree_{nullptr};
|
||||
ObjInfo task_;
|
||||
ObjInfo const* task_{nullptr};
|
||||
|
||||
common::Monitor monitor_;
|
||||
};
|
||||
@ -890,8 +890,8 @@ class GPUHistMaker : public TreeUpdater {
|
||||
#if !defined(GTEST_TEST)
|
||||
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
||||
.describe("Grow tree with GPU.")
|
||||
.set_body([](Context const* ctx, ObjInfo task) { return new GPUHistMaker(ctx, task); });
|
||||
.set_body([](Context const* ctx, ObjInfo const* task) {
|
||||
return new GPUHistMaker(ctx, task);
|
||||
});
|
||||
#endif // !defined(GTEST_TEST)
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@ -18,7 +18,7 @@ DMLC_REGISTRY_FILE_TAG(updater_prune);
|
||||
/*! \brief pruner that prunes a tree after growing finishes */
|
||||
class TreePruner : public TreeUpdater {
|
||||
public:
|
||||
explicit TreePruner(Context const* ctx, ObjInfo task) : TreeUpdater(ctx) {
|
||||
explicit TreePruner(Context const* ctx, ObjInfo const* task) : TreeUpdater(ctx) {
|
||||
syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
|
||||
pruner_monitor_.Init("TreePruner");
|
||||
}
|
||||
@ -90,5 +90,7 @@ class TreePruner : public TreeUpdater {
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
||||
.describe("Pruner that prune the tree according to statistics.")
|
||||
.set_body([](Context const* ctx, ObjInfo task) { return new TreePruner(ctx, task); });
|
||||
.set_body([](Context const* ctx, ObjInfo const* task) {
|
||||
return new TreePruner{ctx, task};
|
||||
});
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@ -35,7 +35,7 @@ void QuantileHistMaker::Update(TrainParam const *param, HostDeviceVector<Gradien
|
||||
// build tree
|
||||
const size_t n_trees = trees.size();
|
||||
if (!pimpl_) {
|
||||
pimpl_.reset(new Builder(n_trees, param, dmat, task_, ctx_));
|
||||
pimpl_.reset(new Builder(n_trees, param, dmat, *task_, ctx_));
|
||||
}
|
||||
|
||||
size_t t_idx{0};
|
||||
@ -287,6 +287,8 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
|
||||
.describe("Grow tree using quantized histogram.")
|
||||
.set_body([](Context const *ctx, ObjInfo task) { return new QuantileHistMaker(ctx, task); });
|
||||
.set_body([](Context const *ctx, ObjInfo const *task) {
|
||||
return new QuantileHistMaker(ctx, task);
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -43,7 +43,8 @@ inline BatchParam HistBatch(TrainParam const* param) {
|
||||
/*! \brief construct a tree using quantized feature values */
|
||||
class QuantileHistMaker: public TreeUpdater {
|
||||
public:
|
||||
explicit QuantileHistMaker(Context const* ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} {}
|
||||
explicit QuantileHistMaker(Context const* ctx, ObjInfo const* task)
|
||||
: TreeUpdater(ctx), task_{task} {}
|
||||
void Configure(const Args&) override {}
|
||||
|
||||
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
@ -125,7 +126,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
|
||||
protected:
|
||||
std::unique_ptr<Builder> pimpl_;
|
||||
ObjInfo task_;
|
||||
ObjInfo const* task_;
|
||||
};
|
||||
} // namespace xgboost::tree
|
||||
|
||||
|
||||
@ -142,5 +142,5 @@ class TreeRefresher : public TreeUpdater {
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
|
||||
.describe("Refresher that refreshes the weight and statistics according to data.")
|
||||
.set_body([](Context const *ctx, ObjInfo) { return new TreeRefresher(ctx); });
|
||||
.set_body([](Context const *ctx, auto) { return new TreeRefresher(ctx); });
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2013 by XBGoost Contributors
|
||||
* Copyright 2014-2023 by XBGoost Contributors
|
||||
* \file updater_sync.cc
|
||||
* \brief synchronize the tree in all distributed nodes
|
||||
*/
|
||||
@ -53,5 +53,5 @@ class TreeSyncher : public TreeUpdater {
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
|
||||
.describe("Syncher that synchronize the tree in all distributed nodes.")
|
||||
.set_body([](Context const* ctx, ObjInfo) { return new TreeSyncher(ctx); });
|
||||
.set_body([](Context const* ctx, auto) { return new TreeSyncher(ctx); });
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@ -170,8 +170,8 @@ void TestHistogramIndexImpl() {
|
||||
|
||||
// Build 2 matrices and build a histogram maker with that
|
||||
Context ctx(CreateEmptyGenericParam(0));
|
||||
tree::GPUHistMaker hist_maker{&ctx, ObjInfo{ObjInfo::kRegression}},
|
||||
hist_maker_ext{&ctx, ObjInfo{ObjInfo::kRegression}};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
tree::GPUHistMaker hist_maker{&ctx, &task}, hist_maker_ext{&ctx, &task};
|
||||
std::unique_ptr<DMatrix> hist_maker_dmat(
|
||||
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
|
||||
|
||||
@ -240,7 +240,8 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
param.UpdateAllowUnknown(args);
|
||||
|
||||
Context ctx(CreateEmptyGenericParam(0));
|
||||
tree::GPUHistMaker hist_maker{&ctx,ObjInfo{ObjInfo::kRegression}};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
tree::GPUHistMaker hist_maker{&ctx, &task};
|
||||
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
hist_maker.Update(¶m, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
|
||||
@ -385,8 +386,8 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
|
||||
|
||||
TEST(GpuHist, ConfigIO) {
|
||||
Context ctx(CreateEmptyGenericParam(0));
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_gpu_hist", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_hist", &ctx, &task)};
|
||||
updater->Configure(Args{});
|
||||
|
||||
Json j_updater { Object() };
|
||||
|
||||
@ -37,13 +37,13 @@ TEST(GrowHistMaker, InteractionConstraint)
|
||||
auto p_gradients = GenerateGradients(kRows);
|
||||
|
||||
Context ctx;
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
{
|
||||
// With constraints
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
TrainParam param;
|
||||
param.UpdateAllowUnknown(
|
||||
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
|
||||
@ -61,8 +61,7 @@ TEST(GrowHistMaker, InteractionConstraint)
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
@ -81,8 +80,8 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
|
||||
auto p_dmat = GenerateDMatrix(rows, cols);
|
||||
auto p_gradients = GenerateGradients(rows);
|
||||
Context ctx;
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
|
||||
std::unique_ptr<DMatrix> sliced{
|
||||
@ -110,12 +109,12 @@ TEST(GrowHistMaker, ColumnSplit) {
|
||||
|
||||
RegTree expected_tree;
|
||||
expected_tree.param.num_feature = kCols;
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
{
|
||||
auto p_dmat = GenerateDMatrix(kRows, kCols);
|
||||
auto p_gradients = GenerateGradients(kRows);
|
||||
Context ctx;
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
|
||||
@ -2,22 +2,25 @@
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/task.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/task.h> // for ObjInfo
|
||||
#include <xgboost/tree_updater.h> // for TreeUpdater
|
||||
|
||||
#include <memory> // for unique_ptr
|
||||
|
||||
namespace xgboost {
|
||||
TEST(Updater, HasNodePosition) {
|
||||
Context ctx;
|
||||
ObjInfo task{ObjInfo::kRegression, true, true};
|
||||
std::unique_ptr<TreeUpdater> up{TreeUpdater::Create("grow_histmaker", &ctx, task)};
|
||||
std::unique_ptr<TreeUpdater> up{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
ASSERT_TRUE(up->HasNodePosition());
|
||||
|
||||
up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, task));
|
||||
up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task));
|
||||
ASSERT_TRUE(up->HasNodePosition());
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
ctx.gpu_id = 0;
|
||||
up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, task));
|
||||
up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, &task));
|
||||
ASSERT_TRUE(up->HasNodePosition());
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -71,8 +72,8 @@ class TestPredictionCache : public ::testing::Test {
|
||||
ctx.gpu_id = Context::kCpuId;
|
||||
}
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create(updater_name, &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, &ctx, &task)};
|
||||
RegTree tree;
|
||||
std::vector<RegTree *> trees{&tree};
|
||||
auto gpair = GenerateRandomGradients(n_samples_);
|
||||
|
||||
@ -39,8 +39,8 @@ TEST(Updater, Prune) {
|
||||
TrainParam param;
|
||||
param.UpdateAllowUnknown(cfg);
|
||||
|
||||
std::unique_ptr<TreeUpdater> pruner(
|
||||
TreeUpdater::Create("prune", &ctx, ObjInfo{ObjInfo::kRegression}));
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> pruner(TreeUpdater::Create("prune", &ctx, &task));
|
||||
|
||||
// loss_chg < min_split_loss;
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(trees.size());
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
/**
|
||||
* Copyright 2018-2013 by XGBoost Contributors
|
||||
* Copyright 2018-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/task.h> // for ObjInfo
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <memory>
|
||||
@ -12,9 +13,7 @@
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
namespace xgboost::tree {
|
||||
TEST(Updater, Refresh) {
|
||||
bst_row_t constexpr kRows = 8;
|
||||
bst_feature_t constexpr kCols = 16;
|
||||
@ -33,8 +32,9 @@ TEST(Updater, Refresh) {
|
||||
auto ctx = CreateEmptyGenericParam(GPUIDX);
|
||||
tree.param.UpdateAllowUnknown(cfg);
|
||||
std::vector<RegTree*> trees{&tree};
|
||||
std::unique_ptr<TreeUpdater> refresher(
|
||||
TreeUpdater::Create("refresh", &ctx, ObjInfo{ObjInfo::kRegression}));
|
||||
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &ctx, &task));
|
||||
|
||||
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f,
|
||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||
@ -57,6 +57,4 @@ TEST(Updater, Refresh) {
|
||||
ASSERT_NEAR(0, tree.Stat(1).loss_chg, kEps);
|
||||
ASSERT_NEAR(0, tree.Stat(2).loss_chg, kEps);
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@ -2,9 +2,13 @@
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/task.h> // for ObjInfo
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <memory> // for unique_ptr
|
||||
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
#include "../helpers.h"
|
||||
|
||||
@ -26,12 +30,12 @@ class UpdaterTreeStatTest : public ::testing::Test {
|
||||
|
||||
void RunTest(std::string updater) {
|
||||
tree::TrainParam param;
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
param.Init(Args{});
|
||||
|
||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||
: CreateEmptyGenericParam(Context::kCpuId));
|
||||
auto up = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up->Configure(Args{});
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
@ -74,18 +78,18 @@ class UpdaterEtaTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
void RunTest(std::string updater) {
|
||||
ObjInfo task{ObjInfo::kClassification};
|
||||
|
||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||
: CreateEmptyGenericParam(Context::kCpuId));
|
||||
|
||||
float eta = 0.4;
|
||||
auto up_0 = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
|
||||
auto up_0 = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up_0->Configure(Args{});
|
||||
tree::TrainParam param0;
|
||||
param0.Init(Args{{"eta", std::to_string(eta)}});
|
||||
|
||||
auto up_1 = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
|
||||
auto up_1 = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up_1->Configure(Args{{"eta", "1.0"}});
|
||||
tree::TrainParam param1;
|
||||
param1.Init(Args{{"eta", "1.0"}});
|
||||
@ -153,11 +157,11 @@ class TestMinSplitLoss : public ::testing::Test {
|
||||
{"gamma", std::to_string(gamma)}};
|
||||
tree::TrainParam param;
|
||||
param.UpdateAllowUnknown(args);
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
|
||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||
: CreateEmptyGenericParam(Context::kCpuId));
|
||||
auto up = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up->Configure({});
|
||||
|
||||
RegTree tree;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user