From f75a21af25d1752648d78d8efae33b6cb1a7fac9 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 20 Dec 2018 15:52:28 +1300 Subject: [PATCH] Reduce tree expand boilerplate code (#4008) --- include/xgboost/tree_model.h | 29 +++++++++++++++++++++-------- src/tree/updater_colmaker.cc | 6 +----- src/tree/updater_gpu_common.cuh | 5 +---- src/tree/updater_gpu_hist.cu | 5 ++--- src/tree/updater_histmaker.cc | 8 ++------ src/tree/updater_quantile_hist.cc | 10 ++-------- src/tree/updater_skmaker.cc | 8 ++------ tests/cpp/tree/test_prune.cc | 4 ++-- tests/cpp/tree/test_refresh.cc | 2 +- 9 files changed, 34 insertions(+), 43 deletions(-) diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 56e2820e8..da5c1d0f4 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -301,18 +301,31 @@ class RegTree { fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()); fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size()); } - /*! - * \brief add child nodes to node - * \param nid node id to add children to + + /** + * \brief Expands a leaf node into two additional leaf nodes + * + * \param nid The node index to expand. + * \param split_index Feature index of the split. + * \param split_value The split condition. + * \param default_left True to default left. */ - void AddChilds(int nid) { + void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left) { int pleft = this->AllocNode(); int pright = this->AllocNode(); - nodes_[nid].SetLeftChild(pleft); - nodes_[nid].SetRightChild(pright); - nodes_[nodes_[nid].LeftChild() ].SetParent(nid, true); - nodes_[nodes_[nid].RightChild()].SetParent(nid, false); + auto &node = nodes_[nid]; + CHECK(node.IsLeaf()); + node.SetLeftChild(pleft); + node.SetRightChild(pright); + nodes_[node.LeftChild()].SetParent(nid, true); + nodes_[node.RightChild()].SetParent(nid, false); + node.SetSplit(split_index, split_value, + default_left); + // mark right child as 0, to indicate fresh leaf + nodes_[pleft].SetLeaf(0.0f, 0); + nodes_[pright].SetLeaf(0.0f, 0); } + /*! * \brief get current depth * \param nid node id diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 99e8842c0..5a3a7d1f3 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -637,11 +637,7 @@ class ColMaker: public TreeUpdater { NodeEntry &e = snode_[nid]; // now we know the solution in snode[nid], set split if (e.best.loss_chg > kRtEps) { - p_tree->AddChilds(nid); - (*p_tree)[nid].SetSplit(e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft()); - // mark right child as 0, to indicate fresh leaf - (*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0); - (*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0); + p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft()); } else { (*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate); } diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index 94b52e971..ded04b1c5 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -296,13 +296,10 @@ inline void Dense2SparseTree(RegTree* p_tree, for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) { const DeviceNodeStats& n = h_nodes[gpu_nid]; if (!n.IsUnused() && !n.IsLeaf()) { - tree.AddChilds(nid); - tree[nid].SetSplit(n.fidx, n.fvalue, n.dir == kLeftDir); + tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir); tree.Stat(nid).loss_chg = n.root_gain; tree.Stat(nid).base_weight = n.weight; tree.Stat(nid).sum_hess = n.sum_gradients.GetHess(); - tree[tree[nid].LeftChild()].SetLeaf(0); - tree[tree[nid].RightChild()].SetLeaf(0); nid++; } else if (n.IsLeaf()) { tree[nid].SetLeaf(n.weight * param.learning_rate); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 09f33a014..1c531423a 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -1184,10 +1184,9 @@ class GPUHistMakerSpecialised{ void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) { // Add new leaves RegTree& tree = *p_tree; - tree.AddChilds(candidate.nid); - auto& parent = tree[candidate.nid]; - parent.SetSplit(candidate.split.findex, candidate.split.fvalue, + tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue, candidate.split.dir == kLeftDir); + auto& parent = tree[candidate.nid]; tree.Stat(candidate.nid).loss_chg = candidate.split.loss_chg; // Set up child constraints diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 936ccd498..448f1cb92 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -243,12 +243,8 @@ class HistMaker: public BaseMaker { p_tree->Stat(nid).loss_chg = best.loss_chg; // now we know the solution in snode[nid], set split if (best.loss_chg > kRtEps) { - p_tree->AddChilds(nid); - (*p_tree)[nid].SetSplit(best.SplitIndex(), - best.split_value, best.DefaultLeft()); - // mark right child as 0, to indicate fresh leaf - (*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0); - (*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0); + p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value, + best.DefaultLeft()); // right side sum TStats right_sum; right_sum.SetSubstract(node_sum, left_sum[wid]); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index b55520cf1..87d41aacf 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -429,14 +429,8 @@ void QuantileHistMaker::Builder::ApplySplit(int nid, /* 1. Create child nodes */ NodeEntry& e = snode_[nid]; - - p_tree->AddChilds(nid); - (*p_tree)[nid].SetSplit(e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft()); - // mark right child as 0, to indicate fresh leaf - int cleft = (*p_tree)[nid].LeftChild(); - int cright = (*p_tree)[nid].RightChild(); - (*p_tree)[cleft].SetLeaf(0.0f, 0); - (*p_tree)[cright].SetLeaf(0.0f, 0); + p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, + e.best.DefaultLeft()); /* 2. Categorize member rows */ const auto nthread = static_cast(this->nthread_); diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc index 405d1c2bf..c19a6a1d1 100644 --- a/src/tree/updater_skmaker.cc +++ b/src/tree/updater_skmaker.cc @@ -285,12 +285,8 @@ class SketchMaker: public BaseMaker { this->SetStats(nid, node_stats_[nid], p_tree); // now we know the solution in snode[nid], set split if (best.loss_chg > kRtEps) { - p_tree->AddChilds(nid); - (*p_tree)[nid].SetSplit(best.SplitIndex(), - best.split_value, best.DefaultLeft()); - // mark right child as 0, to indicate fresh leaf - (*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0); - (*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0); + p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value, + best.DefaultLeft()); } else { (*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate); } diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc index fbebf47b7..82ffdb3c7 100644 --- a/tests/cpp/tree/test_prune.cc +++ b/tests/cpp/tree/test_prune.cc @@ -38,7 +38,7 @@ TEST(Updater, Prune) { pruner->Init(cfg); // loss_chg < min_split_loss; - tree.AddChilds(0); + tree.ExpandNode(0, 0, 0, true); int cleft = tree[0].LeftChild(); int cright = tree[0].RightChild(); tree[cleft].SetLeaf(0.3f, 0); @@ -48,7 +48,7 @@ TEST(Updater, Prune) { ASSERT_EQ(tree.NumExtraNodes(), 0); // loss_chg > min_split_loss; - tree.AddChilds(0); + tree.ExpandNode(0, 0, 0, true); cleft = tree[0].LeftChild(); cright = tree[0].RightChild(); tree[cleft].SetLeaf(0.3f, 0); diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc index d1e66edb1..e9cf565aa 100644 --- a/tests/cpp/tree/test_refresh.cc +++ b/tests/cpp/tree/test_refresh.cc @@ -29,7 +29,7 @@ TEST(Updater, Refresh) { std::vector trees {&tree}; std::unique_ptr refresher(TreeUpdater::Create("refresh")); - tree.AddChilds(0); + tree.ExpandNode(0, 0, 0, true); int cleft = tree[0].LeftChild(); int cright = tree[0].RightChild(); tree[cleft].SetLeaf(0.2f, 0);