From 9bade7203a5744ba4d5b226bedb0dbf7a980081f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 13 Mar 2023 20:55:10 +0800 Subject: [PATCH] Remove public access to tree model param. (#8902) * Make tree model param a private member. * Number of features and targets are immutable after construction. This is to reduce the number of places where we can run configuration. --- include/xgboost/tree_model.h | 146 +++++++++--------- src/gbm/gbtree.cc | 4 +- src/learner.cc | 2 - src/predictor/cpu_predictor.cc | 2 +- src/tree/tree_model.cc | 78 +++++----- src/tree/updater_colmaker.cc | 6 +- src/tree/updater_prune.cc | 2 +- src/tree/updater_refresh.cc | 8 +- tests/cpp/tree/test_histmaker.cc | 14 +- .../cpp/tree/test_multi_target_tree_model.cc | 17 +- tests/cpp/tree/test_prune.cc | 3 +- tests/cpp/tree/test_refresh.cc | 3 +- tests/cpp/tree/test_tree_model.cc | 14 +- tests/cpp/tree/test_tree_stat.cc | 9 +- 14 files changed, 149 insertions(+), 159 deletions(-) diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index f646140dc..61dd94302 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -178,51 +178,33 @@ class RegTree : public Model { } /*! \brief index of left child */ - XGBOOST_DEVICE [[nodiscard]] int LeftChild() const { - return this->cleft_; - } + [[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; } /*! \brief index of right child */ - XGBOOST_DEVICE [[nodiscard]] int RightChild() const { - return this->cright_; - } + [[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; } /*! \brief index of default child when feature is missing */ - XGBOOST_DEVICE [[nodiscard]] int DefaultChild() const { + [[nodiscard]] XGBOOST_DEVICE int DefaultChild() const { return this->DefaultLeft() ? this->LeftChild() : this->RightChild(); } /*! \brief feature index of split condition */ - XGBOOST_DEVICE [[nodiscard]] unsigned SplitIndex() const { + [[nodiscard]] XGBOOST_DEVICE unsigned SplitIndex() const { return sindex_ & ((1U << 31) - 1U); } /*! \brief when feature is unknown, whether goes to left child */ - XGBOOST_DEVICE [[nodiscard]] bool DefaultLeft() const { - return (sindex_ >> 31) != 0; - } + [[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; } /*! \brief whether current node is leaf node */ - XGBOOST_DEVICE [[nodiscard]] bool IsLeaf() const { - return cleft_ == kInvalidNodeId; - } + [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; } /*! \return get leaf value of leaf node */ - XGBOOST_DEVICE [[nodiscard]] float LeafValue() const { - return (this->info_).leaf_value; - } + [[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; } /*! \return get split condition of the node */ - XGBOOST_DEVICE [[nodiscard]] SplitCondT SplitCond() const { - return (this->info_).split_cond; - } + [[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; } /*! \brief get parent of the node */ - XGBOOST_DEVICE [[nodiscard]] int Parent() const { - return parent_ & ((1U << 31) - 1); - } + [[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); } /*! \brief whether current node is left child */ - XGBOOST_DEVICE [[nodiscard]] bool IsLeftChild() const { - return (parent_ & (1U << 31)) != 0; - } + [[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; } /*! 
\brief whether this node is deleted */ - XGBOOST_DEVICE [[nodiscard]] bool IsDeleted() const { - return sindex_ == kDeletedNodeMarker; - } + [[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; } /*! \brief whether current node is root */ - XGBOOST_DEVICE [[nodiscard]] bool IsRoot() const { return parent_ == kInvalidNodeId; } + [[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; } /*! * \brief set the left child * \param nid node id to right child @@ -337,15 +319,13 @@ class RegTree : public Model { this->ChangeToLeaf(rid, value); } - /*! \brief model parameter */ - TreeParam param; RegTree() { - param.Init(Args{}); - nodes_.resize(param.num_nodes); - stats_.resize(param.num_nodes); - split_types_.resize(param.num_nodes, FeatureType::kNumerical); - split_categories_segments_.resize(param.num_nodes); - for (int i = 0; i < param.num_nodes; i++) { + param_.Init(Args{}); + nodes_.resize(param_.num_nodes); + stats_.resize(param_.num_nodes); + split_types_.resize(param_.num_nodes, FeatureType::kNumerical); + split_categories_segments_.resize(param_.num_nodes); + for (int i = 0; i < param_.num_nodes; i++) { nodes_[i].SetLeaf(0.0f); nodes_[i].SetParent(kInvalidNodeId); } @@ -354,10 +334,10 @@ class RegTree : public Model { * \brief Constructor that initializes the tree model with shape. */ explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} { - param.num_feature = n_features; - param.size_leaf_vector = n_targets; + param_.num_feature = n_features; + param_.size_leaf_vector = n_targets; if (n_targets > 1) { - this->p_mt_tree_.reset(new MultiTargetTree{&param}); + this->p_mt_tree_.reset(new MultiTargetTree{&param_}); } } @@ -401,7 +381,7 @@ class RegTree : public Model { bool operator==(const RegTree& b) const { return nodes_ == b.nodes_ && stats_ == b.stats_ && - deleted_nodes_ == b.deleted_nodes_ && param == b.param; + deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_; } /* \brief Iterate through all nodes in this tree. * @@ -459,7 +439,9 @@ class RegTree : public Model { bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child = kInvalidNodeId); - + /** + * \brief Expands a leaf node into two additional leaf nodes for a multi-target tree. + */ void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView base_weight, linalg::VectorView left_weight, linalg::VectorView right_weight); @@ -485,19 +467,48 @@ class RegTree : public Model { bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum); - - [[nodiscard]] bool HasCategoricalSplit() const { - return !split_categories_.empty(); - } + /** + * \brief Whether this tree has categorical split. + */ + [[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); } /** * \brief Whether this is a multi-target tree. */ [[nodiscard]] bool IsMultiTarget() const { return static_cast(p_mt_tree_); } - [[nodiscard]] bst_target_t NumTargets() const { return param.size_leaf_vector; } + /** + * \brief The size of leaf weight. + */ + [[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; } + /** + * \brief Get the underlying implementation of multi-target tree. + */ [[nodiscard]] auto GetMultiTargetTree() const { CHECK(IsMultiTarget()); return p_mt_tree_.get(); } + /** + * \brief Get the number of features.
+ */ + [[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; } + /** + * \brief Get the total number of nodes including deleted ones in this tree. + */ + [[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; } + /** + * \brief Get the total number of valid nodes in this tree. + */ + [[nodiscard]] bst_node_t NumValidNodes() const noexcept { + return param_.num_nodes - param_.num_deleted; + } + /** + * \brief number of extra nodes besides the root + */ + [[nodiscard]] bst_node_t NumExtraNodes() const noexcept { + return param_.num_nodes - 1 - param_.num_deleted; + } + /* \brief Count number of leaves in tree. */ + [[nodiscard]] bst_node_t GetNumLeaves() const; + [[nodiscard]] bst_node_t GetNumSplitNodes() const; /*! * \brief get current depth @@ -514,6 +525,9 @@ class RegTree : public Model { } return depth; } + /** + * \brief Set the leaf weight for a multi-target tree. + */ void SetLeaf(bst_node_t nidx, linalg::VectorView weight) { CHECK(IsMultiTarget()); return this->p_mt_tree_->SetLeaf(nidx, weight); @@ -525,25 +539,13 @@ class RegTree : public Model { */ [[nodiscard]] int MaxDepth(int nid) const { if (nodes_[nid].IsLeaf()) return 0; - return std::max(MaxDepth(nodes_[nid].LeftChild())+1, - MaxDepth(nodes_[nid].RightChild())+1); + return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1); } /*! * \brief get maximum depth */ - int MaxDepth() { - return MaxDepth(0); - } - - /*! \brief number of extra nodes besides the root */ - [[nodiscard]] int NumExtraNodes() const { - return param.num_nodes - 1 - param.num_deleted; - } - - /* \brief Count number of leaves in tree. */ - [[nodiscard]] bst_node_t GetNumLeaves() const; - [[nodiscard]] bst_node_t GetNumSplitNodes() const; + int MaxDepth() { return MaxDepth(0); } /*! * \brief dense feature vector that can be taken by RegTree @@ -735,6 +737,8 @@ class RegTree : public Model { template void LoadCategoricalSplit(Json const& in); void SaveCategoricalSplit(Json* p_out) const; + /*! \brief model parameter */ + TreeParam param_; // vector of nodes std::vector nodes_; // free node space, used during training process @@ -752,20 +756,20 @@ class RegTree : public Model { // allocate a new node, // !!!!!! 
NOTE: may cause BUG here, nodes.resize bst_node_t AllocNode() { - if (param.num_deleted != 0) { + if (param_.num_deleted != 0) { int nid = deleted_nodes_.back(); deleted_nodes_.pop_back(); nodes_[nid].Reuse(); - --param.num_deleted; + --param_.num_deleted; return nid; } - int nd = param.num_nodes++; - CHECK_LT(param.num_nodes, std::numeric_limits::max()) + int nd = param_.num_nodes++; + CHECK_LT(param_.num_nodes, std::numeric_limits::max()) << "number of nodes in the tree exceed 2^31"; - nodes_.resize(param.num_nodes); - stats_.resize(param.num_nodes); - split_types_.resize(param.num_nodes, FeatureType::kNumerical); - split_categories_segments_.resize(param.num_nodes); + nodes_.resize(param_.num_nodes); + stats_.resize(param_.num_nodes); + split_types_.resize(param_.num_nodes, FeatureType::kNumerical); + split_categories_segments_.resize(param_.num_nodes); return nd; } // delete a tree node, keep the parent field to allow trace back @@ -780,7 +784,7 @@ class RegTree : public Model { deleted_nodes_.push_back(nid); nodes_[nid].MarkDelete(); - ++param.num_deleted; + ++param_.num_deleted; } }; diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index c1cb825c1..16609619c 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -360,8 +360,8 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, DMatrix* p_fma << "Set `process_type` to `update` if you want to update existing " "trees."; // create new tree - std::unique_ptr ptr(new RegTree()); - ptr->param.UpdateAllowUnknown(this->cfg_); + std::unique_ptr ptr(new RegTree{this->model_.learner_model_param->LeafLength(), + this->model_.learner_model_param->num_feature}); new_trees.push_back(ptr.get()); ret->push_back(std::move(ptr)); } else if (tparam_.process_type == TreeProcessType::kUpdate) { diff --git a/src/learner.cc b/src/learner.cc index 454855355..62875ead6 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -775,8 +775,6 @@ class LearnerConfiguration : public Learner { } CHECK_NE(mparam_.num_feature, 0) << "0 feature is supplied. Are you using raw Booster interface?"; - // Remove these once binary IO is gone. 
- cfg_["num_feature"] = common::ToString(mparam_.num_feature); } void ConfigureGBM(LearnerTrainParam const& old, Args const& args) { diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 4473173d2..a4b78fefd 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -275,7 +275,7 @@ float FillNodeMeanValues(RegTree const *tree, bst_node_t nidx, std::vector* mean_values) { - size_t num_nodes = tree->param.num_nodes; + size_t num_nodes = tree->NumNodes(); if (mean_values->size() == num_nodes) { return; } diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 0891ec3b2..8f297f46d 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -815,9 +815,9 @@ void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split linalg::VectorView left_weight, linalg::VectorView right_weight) { CHECK(IsMultiTarget()); - CHECK_LT(split_index, this->param.num_feature); + CHECK_LT(split_index, this->param_.num_feature); CHECK(this->p_mt_tree_); - CHECK_GT(param.size_leaf_vector, 1); + CHECK_GT(param_.size_leaf_vector, 1); this->p_mt_tree_->Expand(nidx, split_index, split_cond, default_left, base_weight, left_weight, right_weight); @@ -826,7 +826,7 @@ void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split split_categories_segments_.resize(this->Size()); this->split_types_.at(nidx) = FeatureType::kNumerical; - this->param.num_nodes = this->p_mt_tree_->Size(); + this->param_.num_nodes = this->p_mt_tree_->Size(); } void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index, @@ -850,13 +850,13 @@ void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index, } void RegTree::Load(dmlc::Stream* fi) { - CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam)); + CHECK_EQ(fi->Read(¶m_, sizeof(TreeParam)), sizeof(TreeParam)); if (!DMLC_IO_NO_ENDIAN_SWAP) { - param = param.ByteSwap(); + param_ = param_.ByteSwap(); } - nodes_.resize(param.num_nodes); - stats_.resize(param.num_nodes); - CHECK_NE(param.num_nodes, 0); + nodes_.resize(param_.num_nodes); + stats_.resize(param_.num_nodes); + CHECK_NE(param_.num_nodes, 0); CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()), sizeof(Node) * nodes_.size()); if (!DMLC_IO_NO_ENDIAN_SWAP) { @@ -873,29 +873,29 @@ void RegTree::Load(dmlc::Stream* fi) { } // chg deleted nodes deleted_nodes_.resize(0); - for (int i = 1; i < param.num_nodes; ++i) { + for (int i = 1; i < param_.num_nodes; ++i) { if (nodes_[i].IsDeleted()) { deleted_nodes_.push_back(i); } } - CHECK_EQ(static_cast(deleted_nodes_.size()), param.num_deleted); + CHECK_EQ(static_cast(deleted_nodes_.size()), param_.num_deleted); - split_types_.resize(param.num_nodes, FeatureType::kNumerical); - split_categories_segments_.resize(param.num_nodes); + split_types_.resize(param_.num_nodes, FeatureType::kNumerical); + split_categories_segments_.resize(param_.num_nodes); } void RegTree::Save(dmlc::Stream* fo) const { - CHECK_EQ(param.num_nodes, static_cast(nodes_.size())); - CHECK_EQ(param.num_nodes, static_cast(stats_.size())); - CHECK_EQ(param.deprecated_num_roots, 1); - CHECK_NE(param.num_nodes, 0); + CHECK_EQ(param_.num_nodes, static_cast(nodes_.size())); + CHECK_EQ(param_.num_nodes, static_cast(stats_.size())); + CHECK_EQ(param_.deprecated_num_roots, 1); + CHECK_NE(param_.num_nodes, 0); CHECK(!HasCategoricalSplit()) << "Please use JSON/UBJSON for saving models with categorical splits."; if (DMLC_IO_NO_ENDIAN_SWAP) { - fo->Write(¶m, sizeof(TreeParam)); + 
fo->Write(¶m_, sizeof(TreeParam)); } else { - TreeParam x = param.ByteSwap(); + TreeParam x = param_.ByteSwap(); fo->Write(&x, sizeof(x)); } @@ -1081,7 +1081,7 @@ void RegTree::LoadModel(Json const& in) { bool typed = IsA(in[tf::kParent]); auto const& in_obj = get(in); // basic properties - FromJson(in["tree_param"], ¶m); + FromJson(in["tree_param"], ¶m_); // categorical splits bool has_cat = in_obj.find("split_type") != in_obj.cend(); if (has_cat) { @@ -1092,55 +1092,55 @@ void RegTree::LoadModel(Json const& in) { } } // multi-target - if (param.size_leaf_vector > 1) { - this->p_mt_tree_.reset(new MultiTargetTree{¶m}); + if (param_.size_leaf_vector > 1) { + this->p_mt_tree_.reset(new MultiTargetTree{¶m_}); this->GetMultiTargetTree()->LoadModel(in); return; } bool feature_is_64 = IsA(in["split_indices"]); if (typed && feature_is_64) { - LoadModelImpl(in, param, &stats_, &nodes_); + LoadModelImpl(in, param_, &stats_, &nodes_); } else if (typed && !feature_is_64) { - LoadModelImpl(in, param, &stats_, &nodes_); + LoadModelImpl(in, param_, &stats_, &nodes_); } else if (!typed && feature_is_64) { - LoadModelImpl(in, param, &stats_, &nodes_); + LoadModelImpl(in, param_, &stats_, &nodes_); } else { - LoadModelImpl(in, param, &stats_, &nodes_); + LoadModelImpl(in, param_, &stats_, &nodes_); } if (!has_cat) { - this->split_categories_segments_.resize(this->param.num_nodes); - this->split_types_.resize(this->param.num_nodes); + this->split_categories_segments_.resize(this->param_.num_nodes); + this->split_types_.resize(this->param_.num_nodes); std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical); } deleted_nodes_.clear(); - for (bst_node_t i = 1; i < param.num_nodes; ++i) { + for (bst_node_t i = 1; i < param_.num_nodes; ++i) { if (nodes_[i].IsDeleted()) { deleted_nodes_.push_back(i); } } // easier access to [] operator auto& self = *this; - for (auto nid = 1; nid < param.num_nodes; ++nid) { + for (auto nid = 1; nid < param_.num_nodes; ++nid) { auto parent = self[nid].Parent(); CHECK_NE(parent, RegTree::kInvalidNodeId); self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid); } - CHECK_EQ(static_cast(deleted_nodes_.size()), param.num_deleted); - CHECK_EQ(this->split_categories_segments_.size(), param.num_nodes); + CHECK_EQ(static_cast(deleted_nodes_.size()), param_.num_deleted); + CHECK_EQ(this->split_categories_segments_.size(), param_.num_nodes); } void RegTree::SaveModel(Json* p_out) const { auto& out = *p_out; // basic properties - out["tree_param"] = ToJson(param); + out["tree_param"] = ToJson(param_); // categorical splits this->SaveCategoricalSplit(p_out); // multi-target if (this->IsMultiTarget()) { - CHECK_GT(param.size_leaf_vector, 1); + CHECK_GT(param_.size_leaf_vector, 1); this->GetMultiTargetTree()->SaveModel(p_out); return; } @@ -1150,11 +1150,11 @@ void RegTree::SaveModel(Json* p_out) const { * pruner, and this pruner can be used inside another updater so leaf are not necessary * at the end of node array. 
*/ - CHECK_EQ(param.num_nodes, static_cast(nodes_.size())); - CHECK_EQ(param.num_nodes, static_cast(stats_.size())); + CHECK_EQ(param_.num_nodes, static_cast(nodes_.size())); + CHECK_EQ(param_.num_nodes, static_cast(stats_.size())); - CHECK_EQ(get(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes)); - auto n_nodes = param.num_nodes; + CHECK_EQ(get(out["tree_param"]["num_nodes"]), std::to_string(param_.num_nodes)); + auto n_nodes = param_.num_nodes; // stats F32Array loss_changes(n_nodes); @@ -1168,7 +1168,7 @@ void RegTree::SaveModel(Json* p_out) const { F32Array conds(n_nodes); U8Array default_left(n_nodes); - CHECK_EQ(this->split_types_.size(), param.num_nodes); + CHECK_EQ(this->split_types_.size(), param_.num_nodes); namespace tf = tree_field; @@ -1189,7 +1189,7 @@ void RegTree::SaveModel(Json* p_out) const { default_left.Set(i, static_cast(!!n.DefaultLeft())); } }; - if (this->param.num_feature > static_cast(std::numeric_limits::max())) { + if (this->param_.num_feature > static_cast(std::numeric_limits::max())) { I64Array indices_64(n_nodes); save_tree(&indices_64); out[tf::kSplitIdx] = std::move(indices_64); diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 06579c429..02edfa74a 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -190,7 +190,7 @@ class ColMaker: public TreeUpdater { (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate); } // remember auxiliary statistics in the tree node - for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) { + for (int nid = 0; nid < p_tree->NumNodes(); ++nid) { p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg; p_tree->Stat(nid).base_weight = snode_[nid].weight; p_tree->Stat(nid).sum_hess = static_cast(snode_[nid].stats.sum_hess); @@ -255,9 +255,9 @@ class ColMaker: public TreeUpdater { { // setup statistics space for each tree node for (auto& i : stemp_) { - i.resize(tree.param.num_nodes, ThreadEntry()); + i.resize(tree.NumNodes(), ThreadEntry()); } - snode_.resize(tree.param.num_nodes, NodeEntry()); + snode_.resize(tree.NumNodes(), NodeEntry()); } const MetaInfo& info = fmat.Info(); // setup position diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index 0970d2f79..29f9917ba 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -72,7 +72,7 @@ class TreePruner : public TreeUpdater { void DoPrune(TrainParam const* param, RegTree* p_tree) { auto& tree = *p_tree; bst_node_t npruned = 0; - for (int nid = 0; nid < tree.param.num_nodes; ++nid) { + for (int nid = 0; nid < tree.NumNodes(); ++nid) { if (tree[nid].IsLeaf() && !tree[nid].IsDeleted()) { npruned = this->TryPruneLeaf(param, p_tree, nid, tree.GetDepth(nid), npruned); } diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 4bfe603e0..17c565490 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -50,11 +50,11 @@ class TreeRefresher : public TreeUpdater { int tid = omp_get_thread_num(); int num_nodes = 0; for (auto tree : trees) { - num_nodes += tree->param.num_nodes; + num_nodes += tree->NumNodes(); } stemp[tid].resize(num_nodes, GradStats()); std::fill(stemp[tid].begin(), stemp[tid].end(), GradStats()); - fvec_temp[tid].Init(trees[0]->param.num_feature); + fvec_temp[tid].Init(trees[0]->NumFeatures()); }); } exc.Rethrow(); @@ -77,7 +77,7 @@ class TreeRefresher : public TreeUpdater { for (auto tree : trees) { AddStats(*tree, feats, gpair_h, info, ridx, dmlc::BeginPtr(stemp[tid]) + offset); - offset += tree->param.num_nodes; 
+ offset += tree->NumNodes(); } feats.Drop(inst); }); @@ -96,7 +96,7 @@ class TreeRefresher : public TreeUpdater { int offset = 0; for (auto tree : trees) { this->Refresh(param, dmlc::BeginPtr(stemp[0]) + offset, 0, tree); - offset += tree->param.num_nodes; + offset += tree->NumNodes(); } } diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc index aa6a18797..881de57e1 100644 --- a/tests/cpp/tree/test_histmaker.cc +++ b/tests/cpp/tree/test_histmaker.cc @@ -40,8 +40,7 @@ TEST(GrowHistMaker, InteractionConstraint) ObjInfo task{ObjInfo::kRegression}; { // With constraints - RegTree tree; - tree.param.num_feature = kCols; + RegTree tree{1, kCols}; std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; TrainParam param; @@ -58,8 +57,7 @@ TEST(GrowHistMaker, InteractionConstraint) } { // Without constraints - RegTree tree; - tree.param.num_feature = kCols; + RegTree tree{1u, kCols}; std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; std::vector> position(1); @@ -76,7 +74,7 @@ TEST(GrowHistMaker, InteractionConstraint) } namespace { -void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) { +void TestColumnSplit(int32_t rows, bst_feature_t cols, RegTree const& expected_tree) { auto p_dmat = GenerateDMatrix(rows, cols); auto p_gradients = GenerateGradients(rows); Context ctx; @@ -87,8 +85,7 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) { std::unique_ptr sliced{ p_dmat->SliceCol(collective::GetWorldSize(), collective::GetRank())}; - RegTree tree; - tree.param.num_feature = cols; + RegTree tree{1u, cols}; TrainParam param; param.Init(Args{}); updater->Update(¶m, p_gradients.get(), sliced.get(), position, {&tree}); @@ -107,8 +104,7 @@ TEST(GrowHistMaker, ColumnSplit) { auto constexpr kRows = 32; auto constexpr kCols = 16; - RegTree expected_tree; - expected_tree.param.num_feature = kCols; + RegTree expected_tree{1u, kCols}; ObjInfo task{ObjInfo::kRegression}; { auto p_dmat = GenerateDMatrix(kRows, kCols); diff --git a/tests/cpp/tree/test_multi_target_tree_model.cc b/tests/cpp/tree/test_multi_target_tree_model.cc index 7d2bd9c7c..af83ed7eb 100644 --- a/tests/cpp/tree/test_multi_target_tree_model.cc +++ b/tests/cpp/tree/test_multi_target_tree_model.cc @@ -17,8 +17,8 @@ TEST(MultiTargetTree, JsonIO) { linalg::Vector right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, Context::kCpuId}; tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(), left_weight.HostView(), right_weight.HostView()); - ASSERT_EQ(tree.param.num_nodes, 3); - ASSERT_EQ(tree.param.size_leaf_vector, 3); + ASSERT_EQ(tree.NumNodes(), 3); + ASSERT_EQ(tree.NumTargets(), 3); ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3); ASSERT_EQ(tree.Size(), 3); @@ -26,20 +26,19 @@ TEST(MultiTargetTree, JsonIO) { tree.SaveModel(&jtree); auto check_jtree = [](Json jtree, RegTree const& tree) { - ASSERT_EQ(get(jtree["tree_param"]["num_nodes"]), - std::to_string(tree.param.num_nodes)); + ASSERT_EQ(get(jtree["tree_param"]["num_nodes"]), std::to_string(tree.NumNodes())); ASSERT_EQ(get(jtree["base_weights"]).size(), - tree.param.num_nodes * tree.param.size_leaf_vector); - ASSERT_EQ(get(jtree["parents"]).size(), tree.param.num_nodes); - ASSERT_EQ(get(jtree["left_children"]).size(), tree.param.num_nodes); - ASSERT_EQ(get(jtree["right_children"]).size(), tree.param.num_nodes); + tree.NumNodes() * tree.NumTargets()); + ASSERT_EQ(get(jtree["parents"]).size(), tree.NumNodes()); + 
ASSERT_EQ(get(jtree["left_children"]).size(), tree.NumNodes()); + ASSERT_EQ(get(jtree["right_children"]).size(), tree.NumNodes()); }; check_jtree(jtree, tree); RegTree loaded; loaded.LoadModel(jtree); ASSERT_TRUE(loaded.IsMultiTarget()); - ASSERT_EQ(loaded.param.num_nodes, 3); + ASSERT_EQ(loaded.NumNodes(), 3); Json jtree1{Object{}}; loaded.SaveModel(&jtree1); diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc index 063816def..78161cac9 100644 --- a/tests/cpp/tree/test_prune.cc +++ b/tests/cpp/tree/test_prune.cc @@ -32,8 +32,7 @@ TEST(Updater, Prune) { auto ctx = CreateEmptyGenericParam(GPUIDX); // prepare tree - RegTree tree = RegTree(); - tree.param.UpdateAllowUnknown(cfg); + RegTree tree = RegTree{1u, kCols}; std::vector trees {&tree}; // prepare pruner TrainParam param; diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc index 80a0cbe6f..f46ec2880 100644 --- a/tests/cpp/tree/test_refresh.cc +++ b/tests/cpp/tree/test_refresh.cc @@ -28,9 +28,8 @@ TEST(Updater, Refresh) { {"num_feature", std::to_string(kCols)}, {"reg_lambda", "1"}}; - RegTree tree = RegTree(); + RegTree tree = RegTree{1u, kCols}; auto ctx = CreateEmptyGenericParam(GPUIDX); - tree.param.UpdateAllowUnknown(cfg); std::vector trees{&tree}; ObjInfo task{ObjInfo::kRegression}; diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 130a0ef70..44708ebd1 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -11,9 +11,8 @@ namespace xgboost { TEST(Tree, ModelShape) { bst_feature_t n_features = std::numeric_limits::max(); - RegTree tree; - tree.param.UpdateAllowUnknown(Args{{"num_feature", std::to_string(n_features)}}); - ASSERT_EQ(tree.param.num_feature, n_features); + RegTree tree{1u, n_features}; + ASSERT_EQ(tree.NumFeatures(), n_features); dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/tree.model"; @@ -27,7 +26,7 @@ TEST(Tree, ModelShape) { RegTree new_tree; std::unique_ptr fi(dmlc::Stream::Create(tmp_file.c_str(), "r")); new_tree.Load(fi.get()); - ASSERT_EQ(new_tree.param.num_feature, n_features); + ASSERT_EQ(new_tree.NumFeatures(), n_features); } { // json @@ -39,7 +38,7 @@ TEST(Tree, ModelShape) { auto j_loaded = Json::Load(StringView{dumped.data(), dumped.size()}); new_tree.LoadModel(j_loaded); - ASSERT_EQ(new_tree.param.num_feature, n_features); + ASSERT_EQ(new_tree.NumFeatures(), n_features); } { // ubjson @@ -51,7 +50,7 @@ TEST(Tree, ModelShape) { auto j_loaded = Json::Load(StringView{dumped.data(), dumped.size()}, std::ios::binary); new_tree.LoadModel(j_loaded); - ASSERT_EQ(new_tree.param.num_feature, n_features); + ASSERT_EQ(new_tree.NumFeatures(), n_features); } } @@ -488,8 +487,7 @@ TEST(Tree, JsonIO) { RegTree loaded_tree; loaded_tree.LoadModel(j_tree); - ASSERT_EQ(loaded_tree.param.num_nodes, 3); - + ASSERT_EQ(loaded_tree.NumNodes(), 3); ASSERT_TRUE(loaded_tree == tree); auto left = tree[0].LeftChild(); diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc index a3f5cf9d3..07c51dfcc 100644 --- a/tests/cpp/tree/test_tree_stat.cc +++ b/tests/cpp/tree/test_tree_stat.cc @@ -37,8 +37,7 @@ class UpdaterTreeStatTest : public ::testing::Test { : CreateEmptyGenericParam(Context::kCpuId)); auto up = std::unique_ptr{TreeUpdater::Create(updater, &ctx, &task)}; up->Configure(Args{}); - RegTree tree; - tree.param.num_feature = kCols; + RegTree tree{1u, kCols}; std::vector> position(1); up->Update(¶m, &gpairs_, p_dmat_.get(), position, {&tree}); @@ -95,16 
+94,14 @@ class UpdaterEtaTest : public ::testing::Test { param1.Init(Args{{"eta", "1.0"}}); for (size_t iter = 0; iter < 4; ++iter) { - RegTree tree_0; + RegTree tree_0{1u, kCols}; { - tree_0.param.num_feature = kCols; std::vector> position(1); up_0->Update(&param0, &gpairs_, p_dmat_.get(), position, {&tree_0}); } - RegTree tree_1; + RegTree tree_1{1u, kCols}; { - tree_1.param.num_feature = kCols; std::vector> position(1); up_1->Update(&param1, &gpairs_, p_dmat_.get(), position, {&tree_1}); }
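
Usage note (editor's addition, not part of the patch): after this change, callers fix the tree shape at construction and read it back through the accessors introduced above (NumFeatures(), NumNodes(), NumTargets(), NumExtraNodes()) instead of reaching into tree.param. The following is a minimal, hypothetical sketch assuming XGBoost's internal header xgboost/tree_model.h is on the include path; the single-target ExpandNode call follows the overload whose tail is visible as context in the header hunk, with its leading parameters assumed from the surrounding header, and all argument values are placeholders.

    #include <xgboost/tree_model.h>

    #include <iostream>

    int main() {
      xgboost::bst_feature_t constexpr kCols = 16;
      // Shape (n_targets, n_features) is set once at construction instead of via
      // tree.param.num_feature / tree.param.size_leaf_vector.
      xgboost::RegTree tree{/*n_targets=*/1u, /*n_features=*/kCols};

      // Pre-existing single-target overload; values are arbitrary placeholders.
      tree.ExpandNode(/*nid=*/0, /*split_index=*/3, /*split_value=*/0.5f,
                      /*default_left=*/true, /*base_weight=*/0.0f,
                      /*left_leaf_weight=*/-0.3f, /*right_leaf_weight=*/0.4f,
                      /*loss_change=*/1.0f, /*sum_hess=*/2.0f,
                      /*left_sum=*/1.0f, /*right_sum=*/1.0f);

      // Read-only shape queries go through the accessors rather than tree.param.
      std::cout << tree.NumFeatures() << " features, "   // was tree.param.num_feature
                << tree.NumNodes() << " nodes, "         // was tree.param.num_nodes
                << tree.NumTargets() << " target(s), "   // was tree.param.size_leaf_vector
                << tree.NumExtraNodes() << " non-root nodes\n";
      return 0;
    }

This mirrors the mechanical substitutions made throughout the patch itself, e.g. updater_refresh.cc replacing tree->param.num_nodes with tree->NumNodes().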