Don't use 0 for "fresh leaf". (#5084)
* Allow using right child as marker for Exact tree_method.
This commit is contained in:
parent
df9bdbbcb9
commit
f3d8536702
@ -136,7 +136,7 @@ class RegTree : public Model {
|
|||||||
}
|
}
|
||||||
/*! \brief whether current node is leaf node */
|
/*! \brief whether current node is leaf node */
|
||||||
XGBOOST_DEVICE bool IsLeaf() const {
|
XGBOOST_DEVICE bool IsLeaf() const {
|
||||||
return cleft_ == -1;
|
return cleft_ == kInvalidNodeId;
|
||||||
}
|
}
|
||||||
/*! \return get leaf value of leaf node */
|
/*! \return get leaf value of leaf node */
|
||||||
XGBOOST_DEVICE bst_float LeafValue() const {
|
XGBOOST_DEVICE bst_float LeafValue() const {
|
||||||
@ -159,7 +159,7 @@ class RegTree : public Model {
|
|||||||
return sindex_ == std::numeric_limits<unsigned>::max();
|
return sindex_ == std::numeric_limits<unsigned>::max();
|
||||||
}
|
}
|
||||||
/*! \brief whether current node is root */
|
/*! \brief whether current node is root */
|
||||||
XGBOOST_DEVICE bool IsRoot() const { return parent_ == -1; }
|
XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
|
||||||
/*!
|
/*!
|
||||||
* \brief set the left child
|
* \brief set the left child
|
||||||
* \param nid node id to right child
|
* \param nid node id to right child
|
||||||
@ -192,9 +192,9 @@ class RegTree : public Model {
|
|||||||
* \param right right index, could be used to store
|
* \param right right index, could be used to store
|
||||||
* additional information
|
* additional information
|
||||||
*/
|
*/
|
||||||
XGBOOST_DEVICE void SetLeaf(bst_float value, int right = -1) {
|
XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
|
||||||
(this->info_).leaf_value = value;
|
(this->info_).leaf_value = value;
|
||||||
this->cleft_ = -1;
|
this->cleft_ = kInvalidNodeId;
|
||||||
this->cright_ = right;
|
this->cright_ = right;
|
||||||
}
|
}
|
||||||
/*! \brief mark that this node is deleted */
|
/*! \brief mark that this node is deleted */
|
||||||
@ -275,7 +275,7 @@ class RegTree : public Model {
|
|||||||
stats_.resize(param.num_nodes);
|
stats_.resize(param.num_nodes);
|
||||||
for (int i = 0; i < param.num_nodes; i ++) {
|
for (int i = 0; i < param.num_nodes; i ++) {
|
||||||
nodes_[i].SetLeaf(0.0f);
|
nodes_[i].SetLeaf(0.0f);
|
||||||
nodes_[i].SetParent(-1);
|
nodes_[i].SetParent(kInvalidNodeId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*! \brief get node given nid */
|
/*! \brief get node given nid */
|
||||||
@ -327,11 +327,14 @@ class RegTree : public Model {
|
|||||||
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
|
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
|
||||||
* \param loss_change The loss change.
|
* \param loss_change The loss change.
|
||||||
* \param sum_hess The sum hess.
|
* \param sum_hess The sum hess.
|
||||||
|
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
|
||||||
|
* some updaters use the right child index of leaf as a marker
|
||||||
*/
|
*/
|
||||||
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
|
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
|
||||||
bool default_left, bst_float base_weight,
|
bool default_left, bst_float base_weight,
|
||||||
bst_float left_leaf_weight, bst_float right_leaf_weight,
|
bst_float left_leaf_weight, bst_float right_leaf_weight,
|
||||||
bst_float loss_change, float sum_hess) {
|
bst_float loss_change, float sum_hess,
|
||||||
|
bst_node_t leaf_right_child = kInvalidNodeId) {
|
||||||
int pleft = this->AllocNode();
|
int pleft = this->AllocNode();
|
||||||
int pright = this->AllocNode();
|
int pright = this->AllocNode();
|
||||||
auto &node = nodes_[nid];
|
auto &node = nodes_[nid];
|
||||||
@ -342,9 +345,9 @@ class RegTree : public Model {
|
|||||||
nodes_[node.RightChild()].SetParent(nid, false);
|
nodes_[node.RightChild()].SetParent(nid, false);
|
||||||
node.SetSplit(split_index, split_value,
|
node.SetSplit(split_index, split_value,
|
||||||
default_left);
|
default_left);
|
||||||
// mark right child as 0, to indicate fresh leaf
|
|
||||||
nodes_[pleft].SetLeaf(left_leaf_weight, 0);
|
nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
|
||||||
nodes_[pright].SetLeaf(right_leaf_weight, 0);
|
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
|
||||||
|
|
||||||
this->Stat(nid).loss_chg = loss_change;
|
this->Stat(nid).loss_chg = loss_change;
|
||||||
this->Stat(nid).base_weight = base_weight;
|
this->Stat(nid).base_weight = base_weight;
|
||||||
|
|||||||
@ -594,7 +594,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
||||||
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
||||||
right_leaf_weight, e.best.loss_chg,
|
right_leaf_weight, e.best.loss_chg,
|
||||||
e.stats.sum_hess);
|
e.stats.sum_hess, 0);
|
||||||
} else {
|
} else {
|
||||||
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
|
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user