Don't use 0 for "fresh leaf". (#5084)

* Allow using right child as marker for Exact tree_method.
This commit is contained in:
Jiaming Yuan 2019-12-05 11:50:51 +08:00 committed by GitHub
parent df9bdbbcb9
commit f3d8536702
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 10 deletions

View File

@ -136,7 +136,7 @@ class RegTree : public Model {
}
/*! \brief whether current node is leaf node */
XGBOOST_DEVICE bool IsLeaf() const {
return cleft_ == -1;
return cleft_ == kInvalidNodeId;
}
/*! \return get leaf value of leaf node */
XGBOOST_DEVICE bst_float LeafValue() const {
@ -159,7 +159,7 @@ class RegTree : public Model {
return sindex_ == std::numeric_limits<unsigned>::max();
}
/*! \brief whether current node is root */
XGBOOST_DEVICE bool IsRoot() const { return parent_ == -1; }
XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
/*!
* \brief set the left child
* \param nid node id to right child
@ -192,9 +192,9 @@ class RegTree : public Model {
* \param right right index, could be used to store
* additional information
*/
XGBOOST_DEVICE void SetLeaf(bst_float value, int right = -1) {
XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
(this->info_).leaf_value = value;
this->cleft_ = -1;
this->cleft_ = kInvalidNodeId;
this->cright_ = right;
}
/*! \brief mark that this node is deleted */
@ -275,7 +275,7 @@ class RegTree : public Model {
stats_.resize(param.num_nodes);
for (int i = 0; i < param.num_nodes; i ++) {
nodes_[i].SetLeaf(0.0f);
nodes_[i].SetParent(-1);
nodes_[i].SetParent(kInvalidNodeId);
}
}
/*! \brief get node given nid */
@ -327,11 +327,14 @@ class RegTree : public Model {
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
* \param loss_change The loss change.
* \param sum_hess The sum hess.
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
* some updaters use the right child index of leaf as a marker
*/
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
bool default_left, bst_float base_weight,
bst_float left_leaf_weight, bst_float right_leaf_weight,
bst_float loss_change, float sum_hess) {
bst_float loss_change, float sum_hess,
bst_node_t leaf_right_child = kInvalidNodeId) {
int pleft = this->AllocNode();
int pright = this->AllocNode();
auto &node = nodes_[nid];
@ -342,9 +345,9 @@ class RegTree : public Model {
nodes_[node.RightChild()].SetParent(nid, false);
node.SetSplit(split_index, split_value,
default_left);
// mark right child as 0, to indicate fresh leaf
nodes_[pleft].SetLeaf(left_leaf_weight, 0);
nodes_[pright].SetLeaf(right_leaf_weight, 0);
nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
this->Stat(nid).loss_chg = loss_change;
this->Stat(nid).base_weight = base_weight;

View File

@ -594,7 +594,7 @@ class ColMaker: public TreeUpdater {
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft(), e.weight, left_leaf_weight,
right_leaf_weight, e.best.loss_chg,
e.stats.sum_hess);
e.stats.sum_hess, 0);
} else {
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
}