diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 59f4c2cf8..02248ed8c 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2014-2022 by XGBoost Contributors +/** + * Copyright 2014-2023 by XGBoost Contributors * \file tree_updater.h * \brief General primitive for tree learning, * Updating a collection of trees given the information. @@ -9,19 +9,17 @@ #define XGBOOST_TREE_UPDATER_H_ #include -#include -#include -#include -#include -#include -#include -#include -#include +#include // for Args, GradientPair +#include // DMatrix +#include // for HostDeviceVector +#include // for VectorView +#include // for Configurable +#include // for Span +#include // for RegTree -#include -#include -#include -#include +#include // for function +#include // for string +#include // for vector namespace xgboost { namespace tree { @@ -30,8 +28,9 @@ struct TrainParam; class Json; struct Context; +struct ObjInfo; -/*! +/** * \brief interface of tree update module, that performs update of a tree. */ class TreeUpdater : public Configurable { @@ -53,12 +52,12 @@ class TreeUpdater : public Configurable { * used for modifying existing trees (like `prune`). Return true if it can modify * existing trees. */ - virtual bool CanModifyTree() const { return false; } + [[nodiscard]] virtual bool CanModifyTree() const { return false; } /*! * \brief Wether the out_position in `Update` is valid. This determines whether adaptive * tree can be used. */ - virtual bool HasNodePosition() const { return false; } + [[nodiscard]] virtual bool HasNodePosition() const { return false; } /** * \brief perform update to the tree models * @@ -91,14 +90,15 @@ class TreeUpdater : public Configurable { return false; } - virtual char const* Name() const = 0; + [[nodiscard]] virtual char const* Name() const = 0; - /*! + /** * \brief Create a tree updater given name * \param name Name of the tree updater. * \param ctx A global runtime parameter + * \param task Infomation about the objective. */ - static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo task); + static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task); }; /*! @@ -106,7 +106,7 @@ class TreeUpdater : public Configurable { */ struct TreeUpdaterReg : public dmlc::FunctionRegEntryBase< - TreeUpdaterReg, std::function> {}; + TreeUpdaterReg, std::function> {}; /*! * \brief Macro to register tree updater. diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 39f38c289..c1cb825c1 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -340,7 +340,7 @@ void GBTree::InitUpdater(Args const& cfg) { // create new updaters for (const std::string& pstr : ups) { std::unique_ptr up( - TreeUpdater::Create(pstr.c_str(), ctx_, model_.learner_model_param->task)); + TreeUpdater::Create(pstr.c_str(), ctx_, &model_.learner_model_param->task)); up->Configure(cfg); updaters_.push_back(std::move(up)); } @@ -448,7 +448,7 @@ void GBTree::LoadConfig(Json const& in) { LOG(WARNING) << "Changing updater from `grow_gpu_hist` to `grow_quantile_histmaker`."; } std::unique_ptr up{ - TreeUpdater::Create(name, ctx_, model_.learner_model_param->task)}; + TreeUpdater::Create(name, ctx_, &model_.learner_model_param->task)}; up->LoadConfig(kv.second); updaters_.push_back(std::move(up)); } diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc index 286daa4d8..a1d657b82 100644 --- a/src/tree/tree_updater.cc +++ b/src/tree/tree_updater.cc @@ -1,20 +1,20 @@ -/*! - * Copyright 2015-2022 by XGBoost Contributors +/** + * Copyright 2015-2023 by XGBoost Contributors * \file tree_updater.cc * \brief Registry of tree updaters. */ +#include "xgboost/tree_updater.h" + #include -#include "xgboost/tree_updater.h" -#include "xgboost/host_device_vector.h" +#include // for string namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg); } // namespace dmlc namespace xgboost { - -TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, ObjInfo task) { +TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, ObjInfo const* task) { auto* e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name); if (e == nullptr) { LOG(FATAL) << "Unknown tree updater " << name; @@ -22,11 +22,9 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, Ob auto p_updater = (e->body)(ctx, task); return p_updater; } - } // namespace xgboost -namespace xgboost { -namespace tree { +namespace xgboost::tree { // List of files that will be force linked in static links. DMLC_REGISTRY_LINK_TAG(updater_colmaker); DMLC_REGISTRY_LINK_TAG(updater_refresh); @@ -37,5 +35,4 @@ DMLC_REGISTRY_LINK_TAG(updater_sync); #ifdef XGBOOST_USE_CUDA DMLC_REGISTRY_LINK_TAG(updater_gpu_hist); #endif // XGBOOST_USE_CUDA -} // namespace tree -} // namespace xgboost +} // namespace xgboost::tree diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 2bc3ff543..5af2721a6 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -14,14 +14,15 @@ #include "driver.h" #include "hist/evaluate_splits.h" #include "hist/histogram.h" -#include "hist/sampler.h" // SampleGradient +#include "hist/sampler.h" // for SampleGradient #include "param.h" #include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/json.h" #include "xgboost/linalg.h" +#include "xgboost/task.h" // for ObjInfo #include "xgboost/tree_model.h" -#include "xgboost/tree_updater.h" +#include "xgboost/tree_updater.h" // for TreeUpdater namespace xgboost::tree { @@ -40,12 +41,12 @@ auto BatchSpec(TrainParam const &p, common::Span hess) { class GloablApproxBuilder { protected: - TrainParam const* param_; + TrainParam const *param_; std::shared_ptr col_sampler_; HistEvaluator evaluator_; HistogramBuilder histogram_builder_; Context const *ctx_; - ObjInfo const task_; + ObjInfo const *const task_; std::vector partitioner_; // Pointer to last updated tree, used for update prediction cache. @@ -63,7 +64,8 @@ class GloablApproxBuilder { bst_bin_t n_total_bins = 0; partitioner_.clear(); // Generating the GHistIndexMatrix is quite slow, is there a way to speed it up? - for (auto const &page : p_fmat->GetBatches(BatchSpec(*param_, hess, task_))) { + for (auto const &page : + p_fmat->GetBatches(BatchSpec(*param_, hess, *task_))) { if (n_total_bins == 0) { n_total_bins = page.cut.TotalBins(); feature_values_ = page.cut; @@ -157,7 +159,7 @@ class GloablApproxBuilder { void LeafPartition(RegTree const &tree, common::Span hess, std::vector *p_out_position) { monitor_->Start(__func__); - if (!task_.UpdateTreeLeaf()) { + if (!task_->UpdateTreeLeaf()) { return; } for (auto const &part : partitioner_) { @@ -168,8 +170,8 @@ class GloablApproxBuilder { public: explicit GloablApproxBuilder(TrainParam const *param, MetaInfo const &info, Context const *ctx, - std::shared_ptr column_sampler, ObjInfo task, - common::Monitor *monitor) + std::shared_ptr column_sampler, + ObjInfo const *task, common::Monitor *monitor) : param_{param}, col_sampler_{std::move(column_sampler)}, evaluator_{ctx, param_, info, col_sampler_}, @@ -256,10 +258,11 @@ class GlobalApproxUpdater : public TreeUpdater { DMatrix *cached_{nullptr}; std::shared_ptr column_sampler_ = std::make_shared(); - ObjInfo task_; + ObjInfo const *task_; public: - explicit GlobalApproxUpdater(Context const *ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} { + explicit GlobalApproxUpdater(Context const *ctx, ObjInfo const *task) + : TreeUpdater(ctx), task_{task} { monitor_.Init(__func__); } @@ -317,5 +320,7 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker") .describe( "Tree constructor that uses approximate histogram construction " "for each node.") - .set_body([](Context const *ctx, ObjInfo task) { return new GlobalApproxUpdater(ctx, task); }); + .set_body([](Context const *ctx, ObjInfo const *task) { + return new GlobalApproxUpdater(ctx, task); + }); } // namespace xgboost::tree diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 070bfe578..06579c429 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -603,5 +603,5 @@ class ColMaker: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker") .describe("Grow tree with parallelization over columns.") - .set_body([](Context const *ctx, ObjInfo) { return new ColMaker(ctx); }); + .set_body([](Context const *ctx, auto) { return new ColMaker(ctx); }); } // namespace xgboost::tree diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 607aa8dc4..54ff7ea1a 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -15,12 +15,12 @@ #include "../collective/device_communicator.cuh" #include "../common/bitfield.h" #include "../common/categorical.h" +#include "../common/cuda_context.cuh" // CUDAContext #include "../common/device_helpers.cuh" #include "../common/hist_util.h" #include "../common/io.h" #include "../common/timer.h" #include "../data/ellpack_page.cuh" -#include "../common/cuda_context.cuh" // CUDAContext #include "constraints.cuh" #include "driver.h" #include "gpu_hist/evaluate_splits.cuh" @@ -39,11 +39,10 @@ #include "xgboost/json.h" #include "xgboost/parameter.h" #include "xgboost/span.h" -#include "xgboost/task.h" +#include "xgboost/task.h" // for ObjInfo #include "xgboost/tree_model.h" -namespace xgboost { -namespace tree { +namespace xgboost::tree { #if !defined(GTEST_TEST) DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); #endif // !defined(GTEST_TEST) @@ -106,12 +105,12 @@ class DeviceHistogramStorage { nidx_map_.clear(); overflow_nidx_map_.clear(); } - bool HistogramExists(int nidx) const { + [[nodiscard]] bool HistogramExists(int nidx) const { return nidx_map_.find(nidx) != nidx_map_.cend() || overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend(); } - int Bins() const { return n_bins_; } - size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; } + [[nodiscard]] int Bins() const { return n_bins_; } + [[nodiscard]] size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; } dh::device_vector& Data() { return data_; } void AllocateHistograms(const std::vector& new_nidxs) { @@ -690,8 +689,9 @@ struct GPUHistMakerDevice { return root_entry; } - void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, ObjInfo task, - RegTree* p_tree, collective::DeviceCommunicator* communicator, + void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, + ObjInfo const* task, RegTree* p_tree, + collective::DeviceCommunicator* communicator, HostDeviceVector* p_out_position) { auto& tree = *p_tree; // Process maximum 32 nodes at a time @@ -741,7 +741,7 @@ struct GPUHistMakerDevice { } monitor.Start("FinalisePosition"); - this->FinalisePosition(p_tree, p_fmat, task, p_out_position); + this->FinalisePosition(p_tree, p_fmat, *task, p_out_position); monitor.Stop("FinalisePosition"); } }; @@ -750,7 +750,7 @@ class GPUHistMaker : public TreeUpdater { using GradientSumT = GradientPairPrecise; public: - explicit GPUHistMaker(Context const* ctx, ObjInfo task) + explicit GPUHistMaker(Context const* ctx, ObjInfo const* task) : TreeUpdater(ctx), task_{task} {}; void Configure(const Args& args) override { // Used in test to count how many configurations are performed @@ -872,8 +872,8 @@ class GPUHistMaker : public TreeUpdater { std::unique_ptr> maker; // NOLINT - char const* Name() const override { return "grow_gpu_hist"; } - bool HasNodePosition() const override { return true; } + [[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; } + [[nodiscard]] bool HasNodePosition() const override { return true; } private: bool initialised_{false}; @@ -882,7 +882,7 @@ class GPUHistMaker : public TreeUpdater { DMatrix* p_last_fmat_{nullptr}; RegTree const* p_last_tree_{nullptr}; - ObjInfo task_; + ObjInfo const* task_{nullptr}; common::Monitor monitor_; }; @@ -890,8 +890,8 @@ class GPUHistMaker : public TreeUpdater { #if !defined(GTEST_TEST) XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") .describe("Grow tree with GPU.") - .set_body([](Context const* ctx, ObjInfo task) { return new GPUHistMaker(ctx, task); }); + .set_body([](Context const* ctx, ObjInfo const* task) { + return new GPUHistMaker(ctx, task); + }); #endif // !defined(GTEST_TEST) - -} // namespace tree -} // namespace xgboost +} // namespace xgboost::tree diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index c591ce454..0970d2f79 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -18,7 +18,7 @@ DMLC_REGISTRY_FILE_TAG(updater_prune); /*! \brief pruner that prunes a tree after growing finishes */ class TreePruner : public TreeUpdater { public: - explicit TreePruner(Context const* ctx, ObjInfo task) : TreeUpdater(ctx) { + explicit TreePruner(Context const* ctx, ObjInfo const* task) : TreeUpdater(ctx) { syncher_.reset(TreeUpdater::Create("sync", ctx_, task)); pruner_monitor_.Init("TreePruner"); } @@ -90,5 +90,7 @@ class TreePruner : public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune") .describe("Pruner that prune the tree according to statistics.") - .set_body([](Context const* ctx, ObjInfo task) { return new TreePruner(ctx, task); }); + .set_body([](Context const* ctx, ObjInfo const* task) { + return new TreePruner{ctx, task}; + }); } // namespace xgboost::tree diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 1929efb28..76c402ff5 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -35,7 +35,7 @@ void QuantileHistMaker::Update(TrainParam const *param, HostDeviceVector* gpair, DMatrix* dmat, @@ -125,7 +126,7 @@ class QuantileHistMaker: public TreeUpdater { protected: std::unique_ptr pimpl_; - ObjInfo task_; + ObjInfo const* task_; }; } // namespace xgboost::tree diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index ebda2a999..4bfe603e0 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -142,5 +142,5 @@ class TreeRefresher : public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh") .describe("Refresher that refreshes the weight and statistics according to data.") - .set_body([](Context const *ctx, ObjInfo) { return new TreeRefresher(ctx); }); + .set_body([](Context const *ctx, auto) { return new TreeRefresher(ctx); }); } // namespace xgboost::tree diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc index bb28bc4e6..2422807e2 100644 --- a/src/tree/updater_sync.cc +++ b/src/tree/updater_sync.cc @@ -1,5 +1,5 @@ /** - * Copyright 2014-2013 by XBGoost Contributors + * Copyright 2014-2023 by XBGoost Contributors * \file updater_sync.cc * \brief synchronize the tree in all distributed nodes */ @@ -53,5 +53,5 @@ class TreeSyncher : public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync") .describe("Syncher that synchronize the tree in all distributed nodes.") - .set_body([](Context const* ctx, ObjInfo) { return new TreeSyncher(ctx); }); + .set_body([](Context const* ctx, auto) { return new TreeSyncher(ctx); }); } // namespace xgboost::tree diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index e828d1379..ed21230ed 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -170,8 +170,8 @@ void TestHistogramIndexImpl() { // Build 2 matrices and build a histogram maker with that Context ctx(CreateEmptyGenericParam(0)); - tree::GPUHistMaker hist_maker{&ctx, ObjInfo{ObjInfo::kRegression}}, - hist_maker_ext{&ctx, ObjInfo{ObjInfo::kRegression}}; + ObjInfo task{ObjInfo::kRegression}; + tree::GPUHistMaker hist_maker{&ctx, &task}, hist_maker_ext{&ctx, &task}; std::unique_ptr hist_maker_dmat( CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true)); @@ -240,7 +240,8 @@ void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, param.UpdateAllowUnknown(args); Context ctx(CreateEmptyGenericParam(0)); - tree::GPUHistMaker hist_maker{&ctx,ObjInfo{ObjInfo::kRegression}}; + ObjInfo task{ObjInfo::kRegression}; + tree::GPUHistMaker hist_maker{&ctx, &task}; std::vector> position(1); hist_maker.Update(¶m, gpair, dmat, common::Span>{position}, @@ -385,8 +386,8 @@ TEST(GpuHist, ExternalMemoryWithSampling) { TEST(GpuHist, ConfigIO) { Context ctx(CreateEmptyGenericParam(0)); - std::unique_ptr updater{ - TreeUpdater::Create("grow_gpu_hist", &ctx, ObjInfo{ObjInfo::kRegression})}; + ObjInfo task{ObjInfo::kRegression}; + std::unique_ptr updater{TreeUpdater::Create("grow_gpu_hist", &ctx, &task)}; updater->Configure(Args{}); Json j_updater { Object() }; diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc index 20340f539..aa6a18797 100644 --- a/tests/cpp/tree/test_histmaker.cc +++ b/tests/cpp/tree/test_histmaker.cc @@ -37,13 +37,13 @@ TEST(GrowHistMaker, InteractionConstraint) auto p_gradients = GenerateGradients(kRows); Context ctx; + ObjInfo task{ObjInfo::kRegression}; { // With constraints RegTree tree; tree.param.num_feature = kCols; - std::unique_ptr updater{ - TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})}; + std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; TrainParam param; param.UpdateAllowUnknown( Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}}); @@ -61,8 +61,7 @@ TEST(GrowHistMaker, InteractionConstraint) RegTree tree; tree.param.num_feature = kCols; - std::unique_ptr updater{ - TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})}; + std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; std::vector> position(1); TrainParam param; param.Init(Args{}); @@ -81,8 +80,8 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) { auto p_dmat = GenerateDMatrix(rows, cols); auto p_gradients = GenerateGradients(rows); Context ctx; - std::unique_ptr updater{ - TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})}; + ObjInfo task{ObjInfo::kRegression}; + std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; std::vector> position(1); std::unique_ptr sliced{ @@ -110,12 +109,12 @@ TEST(GrowHistMaker, ColumnSplit) { RegTree expected_tree; expected_tree.param.num_feature = kCols; + ObjInfo task{ObjInfo::kRegression}; { auto p_dmat = GenerateDMatrix(kRows, kCols); auto p_gradients = GenerateGradients(kRows); Context ctx; - std::unique_ptr updater{ - TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})}; + std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; std::vector> position(1); TrainParam param; param.Init(Args{}); diff --git a/tests/cpp/tree/test_node_partition.cc b/tests/cpp/tree/test_node_partition.cc index 883c8e68f..d7254fa60 100644 --- a/tests/cpp/tree/test_node_partition.cc +++ b/tests/cpp/tree/test_node_partition.cc @@ -2,22 +2,25 @@ * Copyright 2023 by XGBoost contributors */ #include -#include -#include +#include // for Context +#include // for ObjInfo +#include // for TreeUpdater + +#include // for unique_ptr namespace xgboost { TEST(Updater, HasNodePosition) { Context ctx; ObjInfo task{ObjInfo::kRegression, true, true}; - std::unique_ptr up{TreeUpdater::Create("grow_histmaker", &ctx, task)}; + std::unique_ptr up{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; ASSERT_TRUE(up->HasNodePosition()); - up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, task)); + up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)); ASSERT_TRUE(up->HasNodePosition()); #if defined(XGBOOST_USE_CUDA) ctx.gpu_id = 0; - up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, task)); + up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, &task)); ASSERT_TRUE(up->HasNodePosition()); #endif // defined(XGBOOST_USE_CUDA) } diff --git a/tests/cpp/tree/test_prediction_cache.cc b/tests/cpp/tree/test_prediction_cache.cc index f4e67d836..4f5a05eb6 100644 --- a/tests/cpp/tree/test_prediction_cache.cc +++ b/tests/cpp/tree/test_prediction_cache.cc @@ -9,6 +9,7 @@ #include "../../../src/tree/param.h" // for TrainParam #include "../helpers.h" +#include "xgboost/task.h" // for ObjInfo namespace xgboost { @@ -71,8 +72,8 @@ class TestPredictionCache : public ::testing::Test { ctx.gpu_id = Context::kCpuId; } - std::unique_ptr updater{ - TreeUpdater::Create(updater_name, &ctx, ObjInfo{ObjInfo::kRegression})}; + ObjInfo task{ObjInfo::kRegression}; + std::unique_ptr updater{TreeUpdater::Create(updater_name, &ctx, &task)}; RegTree tree; std::vector trees{&tree}; auto gpair = GenerateRandomGradients(n_samples_); diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc index 258396976..063816def 100644 --- a/tests/cpp/tree/test_prune.cc +++ b/tests/cpp/tree/test_prune.cc @@ -39,8 +39,8 @@ TEST(Updater, Prune) { TrainParam param; param.UpdateAllowUnknown(cfg); - std::unique_ptr pruner( - TreeUpdater::Create("prune", &ctx, ObjInfo{ObjInfo::kRegression})); + ObjInfo task{ObjInfo::kRegression}; + std::unique_ptr pruner(TreeUpdater::Create("prune", &ctx, &task)); // loss_chg < min_split_loss; std::vector> position(trees.size()); diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc index 870022724..80a0cbe6f 100644 --- a/tests/cpp/tree/test_refresh.cc +++ b/tests/cpp/tree/test_refresh.cc @@ -1,8 +1,9 @@ /** - * Copyright 2018-2013 by XGBoost Contributors + * Copyright 2018-2023 by XGBoost Contributors */ #include #include +#include // for ObjInfo #include #include @@ -12,9 +13,7 @@ #include "../../../src/tree/param.h" // for TrainParam #include "../helpers.h" -namespace xgboost { -namespace tree { - +namespace xgboost::tree { TEST(Updater, Refresh) { bst_row_t constexpr kRows = 8; bst_feature_t constexpr kCols = 16; @@ -33,8 +32,9 @@ TEST(Updater, Refresh) { auto ctx = CreateEmptyGenericParam(GPUIDX); tree.param.UpdateAllowUnknown(cfg); std::vector trees{&tree}; - std::unique_ptr refresher( - TreeUpdater::Create("refresh", &ctx, ObjInfo{ObjInfo::kRegression})); + + ObjInfo task{ObjInfo::kRegression}; + std::unique_ptr refresher(TreeUpdater::Create("refresh", &ctx, &task)); tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f, /*left_sum=*/0.0f, /*right_sum=*/0.0f); @@ -57,6 +57,4 @@ TEST(Updater, Refresh) { ASSERT_NEAR(0, tree.Stat(1).loss_chg, kEps); ASSERT_NEAR(0, tree.Stat(2).loss_chg, kEps); } - -} // namespace tree -} // namespace xgboost +} // namespace xgboost::tree diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc index 4757bb3c1..a3f5cf9d3 100644 --- a/tests/cpp/tree/test_tree_stat.cc +++ b/tests/cpp/tree/test_tree_stat.cc @@ -2,9 +2,13 @@ * Copyright 2020-2023 by XGBoost Contributors */ #include +#include // for Context +#include // for ObjInfo #include #include +#include // for unique_ptr + #include "../../../src/tree/param.h" // for TrainParam #include "../helpers.h" @@ -26,12 +30,12 @@ class UpdaterTreeStatTest : public ::testing::Test { void RunTest(std::string updater) { tree::TrainParam param; + ObjInfo task{ObjInfo::kRegression}; param.Init(Args{}); Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0) : CreateEmptyGenericParam(Context::kCpuId)); - auto up = std::unique_ptr{ - TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})}; + auto up = std::unique_ptr{TreeUpdater::Create(updater, &ctx, &task)}; up->Configure(Args{}); RegTree tree; tree.param.num_feature = kCols; @@ -74,18 +78,18 @@ class UpdaterEtaTest : public ::testing::Test { } void RunTest(std::string updater) { + ObjInfo task{ObjInfo::kClassification}; + Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0) : CreateEmptyGenericParam(Context::kCpuId)); float eta = 0.4; - auto up_0 = std::unique_ptr{ - TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})}; + auto up_0 = std::unique_ptr{TreeUpdater::Create(updater, &ctx, &task)}; up_0->Configure(Args{}); tree::TrainParam param0; param0.Init(Args{{"eta", std::to_string(eta)}}); - auto up_1 = std::unique_ptr{ - TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})}; + auto up_1 = std::unique_ptr{TreeUpdater::Create(updater, &ctx, &task)}; up_1->Configure(Args{{"eta", "1.0"}}); tree::TrainParam param1; param1.Init(Args{{"eta", "1.0"}}); @@ -153,11 +157,11 @@ class TestMinSplitLoss : public ::testing::Test { {"gamma", std::to_string(gamma)}}; tree::TrainParam param; param.UpdateAllowUnknown(args); + ObjInfo task{ObjInfo::kRegression}; Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0) : CreateEmptyGenericParam(Context::kCpuId)); - auto up = std::unique_ptr{ - TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})}; + auto up = std::unique_ptr{TreeUpdater::Create(updater, &ctx, &task)}; up->Configure({}); RegTree tree;