Don't use 0 for "fresh leaf". (#5084)

* Allow using right child as marker for Exact tree_method.
This commit is contained in:
Jiaming Yuan 2019-12-05 11:50:51 +08:00 committed by GitHub
parent df9bdbbcb9
commit f3d8536702
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 10 deletions

View File

@ -136,7 +136,7 @@ class RegTree : public Model {
} }
/*! \brief whether current node is leaf node */ /*! \brief whether current node is leaf node */
XGBOOST_DEVICE bool IsLeaf() const { XGBOOST_DEVICE bool IsLeaf() const {
return cleft_ == -1; return cleft_ == kInvalidNodeId;
} }
/*! \return get leaf value of leaf node */ /*! \return get leaf value of leaf node */
XGBOOST_DEVICE bst_float LeafValue() const { XGBOOST_DEVICE bst_float LeafValue() const {
@ -159,7 +159,7 @@ class RegTree : public Model {
return sindex_ == std::numeric_limits<unsigned>::max(); return sindex_ == std::numeric_limits<unsigned>::max();
} }
/*! \brief whether current node is root */ /*! \brief whether current node is root */
XGBOOST_DEVICE bool IsRoot() const { return parent_ == -1; } XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
/*! /*!
* \brief set the left child * \brief set the left child
* \param nid node id to right child * \param nid node id to right child
@ -192,9 +192,9 @@ class RegTree : public Model {
* \param right right index, could be used to store * \param right right index, could be used to store
* additional information * additional information
*/ */
XGBOOST_DEVICE void SetLeaf(bst_float value, int right = -1) { XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
(this->info_).leaf_value = value; (this->info_).leaf_value = value;
this->cleft_ = -1; this->cleft_ = kInvalidNodeId;
this->cright_ = right; this->cright_ = right;
} }
/*! \brief mark that this node is deleted */ /*! \brief mark that this node is deleted */
@ -275,7 +275,7 @@ class RegTree : public Model {
stats_.resize(param.num_nodes); stats_.resize(param.num_nodes);
for (int i = 0; i < param.num_nodes; i ++) { for (int i = 0; i < param.num_nodes; i ++) {
nodes_[i].SetLeaf(0.0f); nodes_[i].SetLeaf(0.0f);
nodes_[i].SetParent(-1); nodes_[i].SetParent(kInvalidNodeId);
} }
} }
/*! \brief get node given nid */ /*! \brief get node given nid */
@ -327,11 +327,14 @@ class RegTree : public Model {
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate. * \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
* \param loss_change The loss change. * \param loss_change The loss change.
* \param sum_hess The sum hess. * \param sum_hess The sum hess.
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
* some updaters use the right child index of leaf as a marker
*/ */
void ExpandNode(int nid, unsigned split_index, bst_float split_value, void ExpandNode(int nid, unsigned split_index, bst_float split_value,
bool default_left, bst_float base_weight, bool default_left, bst_float base_weight,
bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float left_leaf_weight, bst_float right_leaf_weight,
bst_float loss_change, float sum_hess) { bst_float loss_change, float sum_hess,
bst_node_t leaf_right_child = kInvalidNodeId) {
int pleft = this->AllocNode(); int pleft = this->AllocNode();
int pright = this->AllocNode(); int pright = this->AllocNode();
auto &node = nodes_[nid]; auto &node = nodes_[nid];
@ -342,9 +345,9 @@ class RegTree : public Model {
nodes_[node.RightChild()].SetParent(nid, false); nodes_[node.RightChild()].SetParent(nid, false);
node.SetSplit(split_index, split_value, node.SetSplit(split_index, split_value,
default_left); default_left);
// mark right child as 0, to indicate fresh leaf
nodes_[pleft].SetLeaf(left_leaf_weight, 0); nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
nodes_[pright].SetLeaf(right_leaf_weight, 0); nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
this->Stat(nid).loss_chg = loss_change; this->Stat(nid).loss_chg = loss_change;
this->Stat(nid).base_weight = base_weight; this->Stat(nid).base_weight = base_weight;

View File

@ -594,7 +594,7 @@ class ColMaker: public TreeUpdater {
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft(), e.weight, left_leaf_weight, e.best.DefaultLeft(), e.weight, left_leaf_weight,
right_leaf_weight, e.best.loss_chg, right_leaf_weight, e.best.loss_chg,
e.stats.sum_hess); e.stats.sum_hess, 0);
} else { } else {
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate); (*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
} }