From 1fc37e47497239a7bc52e0e9949004c891fe36b6 Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Fri, 18 Jan 2019 07:12:20 +0200
Subject: [PATCH] Require leaf statistics when expanding tree (#4015)

* Cache left and right gradient sums

* Require leaf statistics when expanding tree
---
 include/xgboost/tree_model.h      | 28 ++++++++++-----
 src/tree/param.h                  | 13 +++++--
 src/tree/updater_colmaker.cc      | 58 +++++++++++++++++++++----------
 src/tree/updater_gpu_common.cuh   |  3 +-
 src/tree/updater_gpu_hist.cu      | 45 ++++++++++--------------
 src/tree/updater_histmaker.cc     | 17 +++++--
 src/tree/updater_quantile_hist.cc | 10 ++++--
 src/tree/updater_skmaker.cc       | 27 ++++++++++----
 tests/cpp/tree/test_param.cc      |  9 +++--
 tests/cpp/tree/test_prune.cc      | 13 ++-----
 tests/cpp/tree/test_refresh.cc    |  5 +--
 11 files changed, 143 insertions(+), 85 deletions(-)

diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h
index da5c1d0f4..4fa2ccad8 100644
--- a/include/xgboost/tree_model.h
+++ b/include/xgboost/tree_model.h
@@ -303,14 +303,22 @@ class RegTree {
   }
 
   /**
-   * \brief Expands a leaf node into two additional leaf nodes
+   * \brief Expands a leaf node into two additional leaf nodes.
    *
-   * \param nid          The node index to expand.
-   * \param split_index  Feature index of the split.
-   * \param split_value  The split condition.
-   * \param default_left True to default left.
+   * \param nid               The node index to expand.
+   * \param split_index       Feature index of the split.
+   * \param split_value       The split condition.
+   * \param default_left      True to default left.
+   * \param base_weight       The base weight, before learning rate.
+   * \param left_leaf_weight  The left leaf weight for prediction, modified by learning rate.
+   * \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
+   * \param loss_change       The loss change.
+   * \param sum_hess          The sum hess.
    */
-  void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left) {
+  void ExpandNode(int nid, unsigned split_index, bst_float split_value,
+                  bool default_left, bst_float base_weight,
+                  bst_float left_leaf_weight, bst_float right_leaf_weight,
+                  bst_float loss_change, float sum_hess) {
     int pleft = this->AllocNode();
     int pright = this->AllocNode();
     auto &node = nodes_[nid];
@@ -322,8 +330,12 @@ class RegTree {
     node.SetSplit(split_index, split_value, default_left);
     // mark right child as 0, to indicate fresh leaf
-    nodes_[pleft].SetLeaf(0.0f, 0);
-    nodes_[pright].SetLeaf(0.0f, 0);
+    nodes_[pleft].SetLeaf(left_leaf_weight, 0);
+    nodes_[pright].SetLeaf(right_leaf_weight, 0);
+
+    this->Stat(nid).loss_chg = loss_change;
+    this->Stat(nid).base_weight = base_weight;
+    this->Stat(nid).sum_hess = sum_hess;
   }
 
   /*!
diff --git a/src/tree/param.h b/src/tree/param.h
index 073f36b1b..c55543a79 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -354,6 +354,8 @@ struct XGBOOST_ALIGNAS(16) GradStats {
   static const int kSimpleStats = 1;
   /*! \brief constructor, the object must be cleared during construction */
   explicit GradStats(const TrainParam& param) { this->Clear(); }
+  explicit GradStats(double sum_grad, double sum_hess)
+      : sum_grad(sum_grad), sum_hess(sum_hess) {}
 
   template <typename GpairT>
   XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
@@ -490,8 +492,10 @@ struct SplitEntry {
   bst_float loss_chg{0.0f};
   /*! \brief split index */
   unsigned sindex{0};
-  /*! \brief split value */
   bst_float split_value{0.0f};
+  GradStats left_sum;
+  GradStats right_sum;
+
   /*! \brief constructor */
   SplitEntry() = default;
   /*!
@@ -521,6 +525,8 @@ struct SplitEntry {
       this->loss_chg = e.loss_chg;
       this->sindex = e.sindex;
       this->split_value = e.split_value;
+      this->left_sum = e.left_sum;
+      this->right_sum = e.right_sum;
       return true;
     } else {
       return false;
@@ -535,7 +541,8 @@ struct SplitEntry {
    * \return whether the proposed split is better and can replace current split
    */
   inline bool Update(bst_float new_loss_chg, unsigned split_index,
-                     bst_float new_split_value, bool default_left) {
+                     bst_float new_split_value, bool default_left,
+                     const GradStats &left_sum, const GradStats &right_sum) {
     if (this->NeedReplace(new_loss_chg, split_index)) {
       this->loss_chg = new_loss_chg;
       if (default_left) {
@@ -543,6 +550,8 @@ struct SplitEntry {
       }
       this->sindex = split_index;
       this->split_value = new_split_value;
+      this->left_sum = left_sum;
+      this->right_sum = right_sum;
       return true;
     } else {
       return false;
diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc
index 5a3a7d1f3..d03fdcefb 100644
--- a/src/tree/updater_colmaker.cc
+++ b/src/tree/updater_colmaker.cc
@@ -311,7 +311,7 @@ class ColMaker: public TreeUpdater {
           auto loss_chg = static_cast<bst_float>(
               spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
               snode_[nid].root_gain);
-          e.best.Update(loss_chg, fid, fsplit, false);
+          e.best.Update(loss_chg, fid, fsplit, false, e.stats, c);
         }
       }
       if (need_backward) {
@@ -322,7 +322,7 @@ class ColMaker: public TreeUpdater {
           auto loss_chg = static_cast<bst_float>(
               spliteval_->ComputeSplitScore(nid, fid, tmp, c) -
               snode_[nid].root_gain);
-          e.best.Update(loss_chg, fid, fsplit, true);
+          e.best.Update(loss_chg, fid, fsplit, true, tmp, c);
         }
       }
     }
@@ -335,7 +335,7 @@ class ColMaker: public TreeUpdater {
           auto loss_chg = static_cast<bst_float>(
               spliteval_->ComputeSplitScore(nid, fid, tmp, c) -
               snode_[nid].root_gain);
-          e.best.Update(loss_chg, fid, e.last_fvalue + kRtEps, true);
+          e.best.Update(loss_chg, fid, e.last_fvalue + kRtEps, true, tmp, c);
         }
       }
     }
@@ -368,7 +368,7 @@ class ColMaker: public TreeUpdater {
               spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
               snode_[nid].root_gain);
           e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f,
-                        false);
+                        false, e.stats, c);
         }
       }
       if (need_backward) {
@@ -379,7 +379,7 @@ class ColMaker: public TreeUpdater {
           auto loss_chg = static_cast<bst_float>(
               spliteval_->ComputeSplitScore(nid, fid, c, cright) -
               snode_[nid].root_gain);
-          e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true);
+          e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true, c, cright);
         }
       }
     }
@@ -410,13 +410,15 @@ class ColMaker: public TreeUpdater {
             loss_chg = static_cast<bst_float>(
                 spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
                 snode_[nid].root_gain);
+            e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
+                          d_step == -1, c, e.stats);
           } else {
             loss_chg = static_cast<bst_float>(
                 spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                 snode_[nid].root_gain);
+            e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
+                          d_step == -1, e.stats, c);
           }
-          e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
-                        d_step == -1);
         }
       }
       // update the statistics
@@ -486,18 +488,21 @@ class ColMaker: public TreeUpdater {
         if (e.stats.sum_hess >= param_.min_child_weight &&
             c.sum_hess >= param_.min_child_weight) {
           bst_float loss_chg;
+          const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
+          const bst_float delta = d_step == +1 ? gap: -gap;
           if (d_step == -1) {
             loss_chg = static_cast<bst_float>(
                 spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
                 snode_[nid].root_gain);
+            e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, c,
+                          e.stats);
           } else {
             loss_chg = static_cast<bst_float>(
                 spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                 snode_[nid].root_gain);
+            e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1,
+                          e.stats, c);
           }
-          const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
-          const bst_float delta = d_step == +1 ? gap: -gap;
-          e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
         }
       }
     }
@@ -545,12 +550,15 @@ class ColMaker: public TreeUpdater {
             loss_chg = static_cast<bst_float>(
                 spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
                 snode_[nid].root_gain);
+            e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
+                          d_step == -1, c, e.stats);
           } else {
             loss_chg = static_cast<bst_float>(
                 spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                 snode_[nid].root_gain);
+            e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
+                          d_step == -1, e.stats, c);
           }
-          e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
         }
       }
       // update the statistics
@@ -565,18 +573,21 @@ class ColMaker: public TreeUpdater {
         if (e.stats.sum_hess >= param_.min_child_weight &&
             c.sum_hess >= param_.min_child_weight) {
           bst_float loss_chg;
+          GradStats left_sum;
+          GradStats right_sum;
           if (d_step == -1) {
-            loss_chg = static_cast<bst_float>(
-                spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
-                snode_[nid].root_gain);
+            left_sum = c;
+            right_sum = e.stats;
           } else {
-            loss_chg = static_cast<bst_float>(
-                spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
-                snode_[nid].root_gain);
+            left_sum = e.stats;
+            right_sum = c;
           }
+          loss_chg = static_cast<bst_float>(
+              spliteval_->ComputeSplitScore(nid, fid, left_sum, right_sum) -
+              snode_[nid].root_gain);
           const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
           const bst_float delta = d_step == +1 ? gap: -gap;
-          e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
+          e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, left_sum, right_sum);
         }
       }
     }
@@ -637,7 +648,16 @@ class ColMaker: public TreeUpdater {
       NodeEntry &e = snode_[nid];
       // now we know the solution in snode[nid], set split
      if (e.best.loss_chg > kRtEps) {
-        p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
+        bst_float left_leaf_weight =
+            spliteval_->ComputeWeight(nid, e.best.left_sum) *
+            param_.learning_rate;
+        bst_float right_leaf_weight =
+            spliteval_->ComputeWeight(nid, e.best.right_sum) *
+            param_.learning_rate;
+        p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
+                           e.best.DefaultLeft(), e.weight, left_leaf_weight,
+                           right_leaf_weight, e.best.loss_chg,
+                           e.stats.sum_hess);
       } else {
         (*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
       }
diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh
index ded04b1c5..63c886b5e 100644
--- a/src/tree/updater_gpu_common.cuh
+++ b/src/tree/updater_gpu_common.cuh
@@ -296,7 +296,8 @@ inline void Dense2SparseTree(RegTree* p_tree,
   for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
     const DeviceNodeStats& n = h_nodes[gpu_nid];
     if (!n.IsUnused() && !n.IsLeaf()) {
-      tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir);
+      tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir, n.weight, 0.0f,
+                      0.0f, n.root_gain, n.sum_gradients.GetHess());
       tree.Stat(nid).loss_chg = n.root_gain;
       tree.Stat(nid).base_weight = n.weight;
       tree.Stat(nid).sum_hess = n.sum_gradients.GetHess();
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 4946aa3e5..831aa11b8 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -1182,42 +1182,35 @@ class GPUHistMakerSpecialised{
   }
 
   void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
-    // Add new leaves
     RegTree& tree = *p_tree;
-    tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
-                    candidate.split.dir == kLeftDir);
-    auto& parent = tree[candidate.nid];
-    tree.Stat(candidate.nid).loss_chg = candidate.split.loss_chg;
-
-    // Set up child constraints
-    node_value_constraints_.resize(tree.GetNodes().size());
     GradStats left_stats(param_);
     left_stats.Add(candidate.split.left_sum);
     GradStats right_stats(param_);
     right_stats.Add(candidate.split.right_sum);
-    node_value_constraints_[candidate.nid].SetChild(
-        param_, parent.SplitIndex(), left_stats, right_stats,
-        &node_value_constraints_[parent.LeftChild()],
-        &node_value_constraints_[parent.RightChild()]);
-
-    // Configure left child
+    GradStats parent_sum(param_);
+    parent_sum.Add(left_stats);
+    parent_sum.Add(right_stats);
+    node_value_constraints_.resize(tree.GetNodes().size());
+    auto base_weight = node_value_constraints_[candidate.nid].CalcWeight(param_, parent_sum);
     auto left_weight =
-        node_value_constraints_[parent.LeftChild()].CalcWeight(param_, left_stats);
-    tree[parent.LeftChild()].SetLeaf(left_weight * param_.learning_rate, 0);
-    tree.Stat(parent.LeftChild()).base_weight = left_weight;
-    tree.Stat(parent.LeftChild()).sum_hess = candidate.split.left_sum.GetHess();
-
-    // Configure right child
+        node_value_constraints_[candidate.nid].CalcWeight(param_, left_stats)*param_.learning_rate;
     auto right_weight =
-        node_value_constraints_[parent.RightChild()].CalcWeight(param_, right_stats);
-    tree[parent.RightChild()].SetLeaf(right_weight * param_.learning_rate, 0);
-    tree.Stat(parent.RightChild()).base_weight = right_weight;
-    tree.Stat(parent.RightChild()).sum_hess = candidate.split.right_sum.GetHess();
+        node_value_constraints_[candidate.nid].CalcWeight(param_, right_stats)*param_.learning_rate;
+    tree.ExpandNode(candidate.nid, candidate.split.findex,
+                    candidate.split.fvalue, candidate.split.dir == kLeftDir,
+                    base_weight, left_weight, right_weight,
+                    candidate.split.loss_chg, parent_sum.sum_hess);
+    // Set up child constraints
+    node_value_constraints_.resize(tree.GetNodes().size());
+    node_value_constraints_[candidate.nid].SetChild(
+        param_, tree[candidate.nid].SplitIndex(), left_stats, right_stats,
+        &node_value_constraints_[tree[candidate.nid].LeftChild()],
+        &node_value_constraints_[tree[candidate.nid].RightChild()]);
 
     // Store sum gradients
     for (auto& shard : shards_) {
-      shard->node_sum_gradients[parent.LeftChild()] = candidate.split.left_sum;
-      shard->node_sum_gradients[parent.RightChild()] = candidate.split.right_sum;
+      shard->node_sum_gradients[tree[candidate.nid].LeftChild()] = candidate.split.left_sum;
+      shard->node_sum_gradients[tree[candidate.nid].RightChild()] = candidate.split.right_sum;
     }
   }
diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc
index 448f1cb92..d0fdd8ac3 100644
--- a/src/tree/updater_histmaker.cc
+++ b/src/tree/updater_histmaker.cc
@@ -192,7 +192,8 @@ class HistMaker: public BaseMaker {
         c.SetSubstract(node_sum, s);
         if (c.sum_hess >= param_.min_child_weight) {
           double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i], false)) {
+          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i],
+                           false, s, c)) {
             *left_sum = s;
           }
         }
@@ -205,7 +206,7 @@ class HistMaker: public BaseMaker {
         c.SetSubstract(node_sum, s);
         if (c.sum_hess >= param_.min_child_weight) {
           double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true)) {
+          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, c, s)) {
             *left_sum = c;
           }
         }
@@ -243,8 +244,18 @@ class HistMaker: public BaseMaker {
       p_tree->Stat(nid).loss_chg = best.loss_chg;
       // now we know the solution in snode[nid], set split
       if (best.loss_chg > kRtEps) {
+        bst_float base_weight = node_sum.CalcWeight(param_);
+        bst_float left_leaf_weight =
+            CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) *
+            param_.learning_rate;
+        bst_float right_leaf_weight =
+            CalcWeight(param_, best.right_sum.sum_grad,
+                       best.right_sum.sum_hess) *
+            param_.learning_rate;
         p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
-                           best.DefaultLeft());
+                           best.DefaultLeft(), base_weight, left_leaf_weight,
+                           right_leaf_weight, best.loss_chg,
+                           node_sum.sum_hess);
         // right side sum
         TStats right_sum;
         right_sum.SetSubstract(node_sum, left_sum[wid]);
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 87d41aacf..62eb57de0 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -429,8 +429,13 @@ void QuantileHistMaker::Builder::ApplySplit(int nid,
   /* 1. Create child nodes */
   NodeEntry& e = snode_[nid];
+  bst_float left_leaf_weight =
+      spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate;
+  bst_float right_leaf_weight =
+      spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
   p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
-                     e.best.DefaultLeft());
+                     e.best.DefaultLeft(), e.weight, left_leaf_weight,
+                     right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
 
   /* 2. Categorize member rows */
   const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
@@ -698,6 +703,7 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
             spliteval_->ComputeSplitScore(nodeID, fid, e, c) -
             snode.root_gain);
         split_pt = cut_val[i];
+        best.Update(loss_chg, fid, split_pt, d_step == -1, e, c);
       } else {
         // backward enumeration: split at left bound of each bin
         loss_chg = static_cast<bst_float>(
@@ -709,8 +715,8 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
         } else {
           split_pt = cut_val[i - 1];
         }
+        best.Update(loss_chg, fid, split_pt, d_step == -1, c, e);
       }
-      best.Update(loss_chg, fid, split_pt, d_step == -1);
     }
   }
 }
diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc
index c19a6a1d1..9e94d5dae 100644
--- a/src/tree/updater_skmaker.cc
+++ b/src/tree/updater_skmaker.cc
@@ -281,12 +281,21 @@ class SketchMaker: public BaseMaker {
       const int nid = qexpand_[wid];
       const SplitEntry &best = sol[wid];
       // set up the values
-      p_tree->Stat(nid).loss_chg = best.loss_chg;
       this->SetStats(nid, node_stats_[nid], p_tree);
       // now we know the solution in snode[nid], set split
       if (best.loss_chg > kRtEps) {
+        bst_float base_weight = node_stats_[nid].CalcWeight(param_);
+        bst_float left_leaf_weight =
+            CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) *
+            param_.learning_rate;
+        bst_float right_leaf_weight =
+            CalcWeight(param_, best.right_sum.sum_grad,
+                       best.right_sum.sum_hess) *
+            param_.learning_rate;
         p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
-                           best.DefaultLeft());
+                           best.DefaultLeft(), base_weight, left_leaf_weight,
+                           right_leaf_weight, best.loss_chg,
+                           node_stats_[nid].sum_hess);
       } else {
         (*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
       }
@@ -336,7 +345,9 @@ class SketchMaker: public BaseMaker {
       if (s.sum_hess >= param_.min_child_weight &&
           c.sum_hess >= param_.min_child_weight) {
         double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-        best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], false);
+        best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], false,
+                     GradStats(s.pos_grad - s.neg_grad, s.sum_hess),
+                     GradStats(c.pos_grad - c.neg_grad, c.sum_hess));
       }
       // backward
       c.SetSubstract(feat_sum, s);
@@ -344,7 +355,9 @@ class SketchMaker: public BaseMaker {
       if (s.sum_hess >= param_.min_child_weight &&
           c.sum_hess >= param_.min_child_weight) {
         double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-        best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], true);
+        best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], true,
+                     GradStats(s.pos_grad - s.neg_grad, s.sum_hess),
+                     GradStats(c.pos_grad - c.neg_grad, c.sum_hess));
       }
     }
     {
@@ -355,8 +368,10 @@ class SketchMaker: public BaseMaker {
           c.sum_hess >= param_.min_child_weight) {
         bst_float cpt = fsplits.back();
         double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-        best->Update(static_cast<bst_float>(loss_chg),
-                     fid, cpt + std::abs(cpt) + 1.0f, false);
+        best->Update(static_cast<bst_float>(loss_chg), fid,
+                     cpt + std::abs(cpt) + 1.0f, false,
+                     GradStats(s.pos_grad - s.neg_grad, s.sum_hess),
+                     GradStats(c.pos_grad - c.neg_grad, c.sum_hess));
       }
     }
   }
diff --git a/tests/cpp/tree/test_param.cc b/tests/cpp/tree/test_param.cc
index 6f2a84c74..3f4e50ba2 100644
--- a/tests/cpp/tree/test_param.cc
+++ b/tests/cpp/tree/test_param.cc
@@ -82,12 +82,15 @@ TEST(Param, SplitEntry) {
 
   xgboost::tree::SplitEntry se2;
   EXPECT_FALSE(se1.Update(se2));
-  EXPECT_FALSE(se2.Update(-1, 100, 0, true));
-  ASSERT_TRUE(se2.Update(1, 100, 0, true));
+  EXPECT_FALSE(se2.Update(-1, 100, 0, true, xgboost::tree::GradStats(),
+                          xgboost::tree::GradStats()));
+  ASSERT_TRUE(se2.Update(1, 100, 0, true, xgboost::tree::GradStats(),
+                         xgboost::tree::GradStats()));
   ASSERT_TRUE(se1.Update(se2));
 
   xgboost::tree::SplitEntry se3;
-  se3.Update(2, 101, 0, false);
+  se3.Update(2, 101, 0, false, xgboost::tree::GradStats(),
+             xgboost::tree::GradStats());
   xgboost::tree::SplitEntry::Reduce(se2, se3);
   EXPECT_EQ(se2.SplitIndex(), 101);
   EXPECT_FALSE(se2.DefaultLeft());
diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc
index 82ffdb3c7..8206a39be 100644
--- a/tests/cpp/tree/test_prune.cc
+++ b/tests/cpp/tree/test_prune.cc
@@ -38,22 +38,13 @@ TEST(Updater, Prune) {
   pruner->Init(cfg);
 
   // loss_chg < min_split_loss;
-  tree.ExpandNode(0, 0, 0, true);
-  int cleft = tree[0].LeftChild();
-  int cright = tree[0].RightChild();
-  tree[cleft].SetLeaf(0.3f, 0);
-  tree[cright].SetLeaf(0.4f, 0);
+  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f);
   pruner->Update(&gpair, dmat->get(), trees);
   ASSERT_EQ(tree.NumExtraNodes(), 0);
 
   // loss_chg > min_split_loss;
-  tree.ExpandNode(0, 0, 0, true);
-  cleft = tree[0].LeftChild();
-  cright = tree[0].RightChild();
-  tree[cleft].SetLeaf(0.3f, 0);
-  tree[cright].SetLeaf(0.4f, 0);
-  tree.Stat(0).loss_chg = 11;
+  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f);
   pruner->Update(&gpair, dmat->get(), trees);
   ASSERT_EQ(tree.NumExtraNodes(), 2);
diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc
index e9cf565aa..cbd06d609 100644
--- a/tests/cpp/tree/test_refresh.cc
+++ b/tests/cpp/tree/test_refresh.cc
@@ -29,12 +29,9 @@ TEST(Updater, Refresh) {
   std::vector<RegTree*> trees {&tree};
   std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh"));
 
-  tree.ExpandNode(0, 0, 0, true);
+  tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f);
   int cleft = tree[0].LeftChild();
   int cright = tree[0].RightChild();
-  tree[cleft].SetLeaf(0.2f, 0);
-  tree[cright].SetLeaf(0.8f, 0);
-  tree[0].SetSplit(2, 0.2f);
   tree.Stat(cleft).base_weight = 1.2;
   tree.Stat(cright).base_weight = 1.3;
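
Caller-side sketch of the new contract (illustrative only, not part of the patch;
the numeric values are invented, while real updaters derive them from the
GradStats of the winning SplitEntry, as the hunks above do):

    #include <xgboost/tree_model.h>

    void ExpandRootExample(xgboost::RegTree* p_tree) {
      // Hypothetical statistics an updater would read from its best split.
      const xgboost::bst_float base_weight = -0.15f;       // weight before learning rate
      const xgboost::bst_float left_leaf_weight = -0.05f;  // already scaled by learning rate
      const xgboost::bst_float right_leaf_weight = 0.08f;  // already scaled by learning rate
      const xgboost::bst_float loss_chg = 1.5f;            // gain of the chosen split
      const float sum_hess = 10.0f;                        // hessian sum over the node
      // A single call now creates both children as fresh leaves carrying their
      // leaf values and fills the parent's RTreeNodeStat (loss_chg, base_weight,
      // sum_hess), instead of the caller patching SetLeaf()/Stat() afterwards.
      p_tree->ExpandNode(/*nid=*/0, /*split_index=*/2, /*split_value=*/0.2f,
                         /*default_left=*/false, base_weight, left_leaf_weight,
                         right_leaf_weight, loss_chg, sum_hess);
    }

This is why the pruner and refresher tests above collapse five or six lines of
manual tree surgery into one ExpandNode call.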