Remove public access to tree model param. (#8902)
* Make tree model param a private member. * Number of features and targets are immutable after construction. This is to reduce the number of places where we can run configuration.
This commit is contained in:
@@ -178,51 +178,33 @@ class RegTree : public Model {
|
||||
}
|
||||
|
||||
/*! \brief index of left child */
|
||||
XGBOOST_DEVICE [[nodiscard]] int LeftChild() const {
|
||||
return this->cleft_;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
|
||||
/*! \brief index of right child */
|
||||
XGBOOST_DEVICE [[nodiscard]] int RightChild() const {
|
||||
return this->cright_;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
|
||||
/*! \brief index of default child when feature is missing */
|
||||
XGBOOST_DEVICE [[nodiscard]] int DefaultChild() const {
|
||||
[[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
|
||||
return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
|
||||
}
|
||||
/*! \brief feature index of split condition */
|
||||
XGBOOST_DEVICE [[nodiscard]] unsigned SplitIndex() const {
|
||||
[[nodiscard]] XGBOOST_DEVICE unsigned SplitIndex() const {
|
||||
return sindex_ & ((1U << 31) - 1U);
|
||||
}
|
||||
/*! \brief when feature is unknown, whether goes to left child */
|
||||
XGBOOST_DEVICE [[nodiscard]] bool DefaultLeft() const {
|
||||
return (sindex_ >> 31) != 0;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
|
||||
/*! \brief whether current node is leaf node */
|
||||
XGBOOST_DEVICE [[nodiscard]] bool IsLeaf() const {
|
||||
return cleft_ == kInvalidNodeId;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
|
||||
/*! \return get leaf value of leaf node */
|
||||
XGBOOST_DEVICE [[nodiscard]] float LeafValue() const {
|
||||
return (this->info_).leaf_value;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
|
||||
/*! \return get split condition of the node */
|
||||
XGBOOST_DEVICE [[nodiscard]] SplitCondT SplitCond() const {
|
||||
return (this->info_).split_cond;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
|
||||
/*! \brief get parent of the node */
|
||||
XGBOOST_DEVICE [[nodiscard]] int Parent() const {
|
||||
return parent_ & ((1U << 31) - 1);
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
|
||||
/*! \brief whether current node is left child */
|
||||
XGBOOST_DEVICE [[nodiscard]] bool IsLeftChild() const {
|
||||
return (parent_ & (1U << 31)) != 0;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
|
||||
/*! \brief whether this node is deleted */
|
||||
XGBOOST_DEVICE [[nodiscard]] bool IsDeleted() const {
|
||||
return sindex_ == kDeletedNodeMarker;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
|
||||
/*! \brief whether current node is root */
|
||||
XGBOOST_DEVICE [[nodiscard]] bool IsRoot() const { return parent_ == kInvalidNodeId; }
|
||||
[[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
|
||||
/*!
|
||||
* \brief set the left child
|
||||
* \param nid node id to left child
|
||||
@@ -337,15 +319,13 @@ class RegTree : public Model {
|
||||
this->ChangeToLeaf(rid, value);
|
||||
}
|
||||
|
||||
/*! \brief model parameter */
|
||||
TreeParam param;
|
||||
RegTree() {
|
||||
param.Init(Args{});
|
||||
nodes_.resize(param.num_nodes);
|
||||
stats_.resize(param.num_nodes);
|
||||
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(param.num_nodes);
|
||||
for (int i = 0; i < param.num_nodes; i++) {
|
||||
param_.Init(Args{});
|
||||
nodes_.resize(param_.num_nodes);
|
||||
stats_.resize(param_.num_nodes);
|
||||
split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(param_.num_nodes);
|
||||
for (int i = 0; i < param_.num_nodes; i++) {
|
||||
nodes_[i].SetLeaf(0.0f);
|
||||
nodes_[i].SetParent(kInvalidNodeId);
|
||||
}
|
||||
@@ -354,10 +334,10 @@ class RegTree : public Model {
|
||||
* \brief Constructor that initializes the tree model with shape.
|
||||
*/
|
||||
explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
|
||||
param.num_feature = n_features;
|
||||
param.size_leaf_vector = n_targets;
|
||||
param_.num_feature = n_features;
|
||||
param_.size_leaf_vector = n_targets;
|
||||
if (n_targets > 1) {
|
||||
this->p_mt_tree_.reset(new MultiTargetTree{&param});
|
||||
this->p_mt_tree_.reset(new MultiTargetTree{&param_});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -401,7 +381,7 @@ class RegTree : public Model {
|
||||
|
||||
bool operator==(const RegTree& b) const {
|
||||
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
|
||||
deleted_nodes_ == b.deleted_nodes_ && param == b.param;
|
||||
deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
|
||||
}
|
||||
/* \brief Iterate through all nodes in this tree.
|
||||
*
|
||||
@@ -459,7 +439,9 @@ class RegTree : public Model {
|
||||
bst_float loss_change, float sum_hess, float left_sum,
|
||||
float right_sum,
|
||||
bst_node_t leaf_right_child = kInvalidNodeId);
|
||||
|
||||
/**
|
||||
* \brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
|
||||
*/
|
||||
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
|
||||
linalg::VectorView<float const> base_weight,
|
||||
linalg::VectorView<float const> left_weight,
|
||||
@@ -485,19 +467,48 @@ class RegTree : public Model {
|
||||
bst_float base_weight, bst_float left_leaf_weight,
|
||||
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
|
||||
float left_sum, float right_sum);
|
||||
|
||||
[[nodiscard]] bool HasCategoricalSplit() const {
|
||||
return !split_categories_.empty();
|
||||
}
|
||||
/**
|
||||
* \brief Whether this tree has categorical split.
|
||||
*/
|
||||
[[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
|
||||
/**
|
||||
* \brief Whether this is a multi-target tree.
|
||||
*/
|
||||
[[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
|
||||
[[nodiscard]] bst_target_t NumTargets() const { return param.size_leaf_vector; }
|
||||
/**
|
||||
* \brief The size of leaf weight.
|
||||
*/
|
||||
[[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
|
||||
/**
|
||||
* \brief Get the underlying implementation of multi-target tree.
|
||||
*/
|
||||
[[nodiscard]] auto GetMultiTargetTree() const {
|
||||
CHECK(IsMultiTarget());
|
||||
return p_mt_tree_.get();
|
||||
}
|
||||
/**
|
||||
* \brief Get the number of features.
|
||||
*/
|
||||
[[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
|
||||
/**
|
||||
* \brief Get the total number of nodes including deleted ones in this tree.
|
||||
*/
|
||||
[[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
|
||||
/**
|
||||
* \brief Get the total number of valid nodes in this tree.
|
||||
*/
|
||||
[[nodiscard]] bst_node_t NumValidNodes() const noexcept {
|
||||
return param_.num_nodes - param_.num_deleted;
|
||||
}
|
||||
/**
|
||||
* \brief number of extra nodes besides the root
|
||||
*/
|
||||
[[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
|
||||
return param_.num_nodes - 1 - param_.num_deleted;
|
||||
}
|
||||
/* \brief Count number of leaves in tree. */
|
||||
[[nodiscard]] bst_node_t GetNumLeaves() const;
|
||||
[[nodiscard]] bst_node_t GetNumSplitNodes() const;
|
||||
|
||||
/*!
|
||||
* \brief get current depth
|
||||
@@ -514,6 +525,9 @@ class RegTree : public Model {
|
||||
}
|
||||
return depth;
|
||||
}
|
||||
/**
|
||||
* \brief Set the leaf weight for a multi-target tree.
|
||||
*/
|
||||
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
|
||||
CHECK(IsMultiTarget());
|
||||
return this->p_mt_tree_->SetLeaf(nidx, weight);
|
||||
@@ -525,25 +539,13 @@ class RegTree : public Model {
|
||||
*/
|
||||
[[nodiscard]] int MaxDepth(int nid) const {
|
||||
if (nodes_[nid].IsLeaf()) return 0;
|
||||
return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
|
||||
MaxDepth(nodes_[nid].RightChild())+1);
|
||||
return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief get maximum depth
|
||||
*/
|
||||
int MaxDepth() {
|
||||
return MaxDepth(0);
|
||||
}
|
||||
|
||||
/*! \brief number of extra nodes besides the root */
|
||||
[[nodiscard]] int NumExtraNodes() const {
|
||||
return param.num_nodes - 1 - param.num_deleted;
|
||||
}
|
||||
|
||||
/* \brief Count number of leaves in tree. */
|
||||
[[nodiscard]] bst_node_t GetNumLeaves() const;
|
||||
[[nodiscard]] bst_node_t GetNumSplitNodes() const;
|
||||
int MaxDepth() { return MaxDepth(0); }
|
||||
|
||||
/*!
|
||||
* \brief dense feature vector that can be taken by RegTree
|
||||
@@ -735,6 +737,8 @@ class RegTree : public Model {
|
||||
template <bool typed>
|
||||
void LoadCategoricalSplit(Json const& in);
|
||||
void SaveCategoricalSplit(Json* p_out) const;
|
||||
/*! \brief model parameter */
|
||||
TreeParam param_;
|
||||
// vector of nodes
|
||||
std::vector<Node> nodes_;
|
||||
// free node space, used during training process
|
||||
@@ -752,20 +756,20 @@ class RegTree : public Model {
|
||||
// allocate a new node,
|
||||
// !!!!!! NOTE: may cause BUG here, nodes.resize
|
||||
bst_node_t AllocNode() {
|
||||
if (param.num_deleted != 0) {
|
||||
if (param_.num_deleted != 0) {
|
||||
int nid = deleted_nodes_.back();
|
||||
deleted_nodes_.pop_back();
|
||||
nodes_[nid].Reuse();
|
||||
--param.num_deleted;
|
||||
--param_.num_deleted;
|
||||
return nid;
|
||||
}
|
||||
int nd = param.num_nodes++;
|
||||
CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
|
||||
int nd = param_.num_nodes++;
|
||||
CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
|
||||
<< "number of nodes in the tree exceed 2^31";
|
||||
nodes_.resize(param.num_nodes);
|
||||
stats_.resize(param.num_nodes);
|
||||
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(param.num_nodes);
|
||||
nodes_.resize(param_.num_nodes);
|
||||
stats_.resize(param_.num_nodes);
|
||||
split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(param_.num_nodes);
|
||||
return nd;
|
||||
}
|
||||
// delete a tree node, keep the parent field to allow trace back
|
||||
@@ -780,7 +784,7 @@ class RegTree : public Model {
|
||||
|
||||
deleted_nodes_.push_back(nid);
|
||||
nodes_[nid].MarkDelete();
|
||||
++param.num_deleted;
|
||||
++param_.num_deleted;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user