From f3d85367028c4f737bbe01453c097b269ec8f4d5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 5 Dec 2019 11:50:51 +0800 Subject: [PATCH] Don't use 0 for "fresh leaf". (#5084) * Allow using right child as marker for Exact tree_method. --- include/xgboost/tree_model.h | 21 ++++++++++++--------- src/tree/updater_colmaker.cc | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 4b1f8b834..35ab56079 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -136,7 +136,7 @@ class RegTree : public Model { } /*! \brief whether current node is leaf node */ XGBOOST_DEVICE bool IsLeaf() const { - return cleft_ == -1; + return cleft_ == kInvalidNodeId; } /*! \return get leaf value of leaf node */ XGBOOST_DEVICE bst_float LeafValue() const { @@ -159,7 +159,7 @@ class RegTree : public Model { return sindex_ == std::numeric_limits::max(); } /*! \brief whether current node is root */ - XGBOOST_DEVICE bool IsRoot() const { return parent_ == -1; } + XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; } /*! * \brief set the left child * \param nid node id to right child @@ -192,9 +192,9 @@ class RegTree : public Model { * \param right right index, could be used to store * additional information */ - XGBOOST_DEVICE void SetLeaf(bst_float value, int right = -1) { + XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) { (this->info_).leaf_value = value; - this->cleft_ = -1; + this->cleft_ = kInvalidNodeId; this->cright_ = right; } /*! \brief mark that this node is deleted */ @@ -275,7 +275,7 @@ class RegTree : public Model { stats_.resize(param.num_nodes); for (int i = 0; i < param.num_nodes; i ++) { nodes_[i].SetLeaf(0.0f); - nodes_[i].SetParent(-1); + nodes_[i].SetParent(kInvalidNodeId); } } /*! \brief get node given nid */ @@ -327,11 +327,14 @@ class RegTree : public Model { * \param right_leaf_weight The right leaf weight for prediction, modified by learning rate. * \param loss_change The loss change. * \param sum_hess The sum hess. + * \param leaf_right_child The right child index of leaf, by default kInvalidNodeId, + * some updaters use the right child index of leaf as a marker */ void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, - bst_float loss_change, float sum_hess) { + bst_float loss_change, float sum_hess, + bst_node_t leaf_right_child = kInvalidNodeId) { int pleft = this->AllocNode(); int pright = this->AllocNode(); auto &node = nodes_[nid]; @@ -342,9 +345,9 @@ class RegTree : public Model { nodes_[node.RightChild()].SetParent(nid, false); node.SetSplit(split_index, split_value, default_left); - // mark right child as 0, to indicate fresh leaf - nodes_[pleft].SetLeaf(left_leaf_weight, 0); - nodes_[pright].SetLeaf(right_leaf_weight, 0); + + nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child); + nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child); this->Stat(nid).loss_chg = loss_change; this->Stat(nid).base_weight = base_weight; diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 3039fdd77..dbd6cc6e9 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -594,7 +594,7 @@ class ColMaker: public TreeUpdater { p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft(), e.weight, left_leaf_weight, right_leaf_weight, e.best.loss_chg, - e.stats.sum_hess); + e.stats.sum_hess, 0); } else { (*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate); }