diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index d7f730f5b..e7f6dc8ec 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -22,6 +22,7 @@ #include <limits> #include <vector> #include <string> +#include <stack> namespace xgboost { @@ -88,6 +89,10 @@ struct RTreeNodeStat { bst_float base_weight; /*! \brief number of child that is leaf node known up to now */ int leaf_child_cnt {0}; + + RTreeNodeStat() = default; + RTreeNodeStat(float loss_chg, float sum_hess, float weight) : + loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {} bool operator==(const RTreeNodeStat& b) const { return loss_chg == b.loss_chg && sum_hess == b.sum_hess && base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt; @@ -101,8 +106,9 @@ struct RTreeNodeStat { class RegTree : public Model { public: using SplitCondT = bst_float; - static constexpr int32_t kInvalidNodeId {-1}; + static constexpr bst_node_t kInvalidNodeId {-1}; static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max(); + static constexpr bst_node_t kRoot { 0 }; /*! \brief tree node */ class Node { @@ -321,6 +327,31 @@ class RegTree : public Model { return nodes_ == b.nodes_ && stats_ == b.stats_ && deleted_nodes_ == b.deleted_nodes_ && param == b.param; } + /* \brief Iterate through all nodes in this tree. + * + * \param Function that accepts a node index, and returns false when iteration should + * stop, otherwise returns true. + */ + template <typename Func> void WalkTree(Func func) const { + std::stack<bst_node_t> nodes; + nodes.push(kRoot); + auto &self = *this; + while (!nodes.empty()) { + auto nidx = nodes.top(); + nodes.pop(); + if (!func(nidx)) { + return; + } + auto left = self[nidx].LeftChild(); + auto right = self[nidx].RightChild(); + if (left != RegTree::kInvalidNodeId) { + nodes.push(left); + } + if (right != RegTree::kInvalidNodeId) { + nodes.push(right); + } + } + } /*! * \brief Compares whether 2 trees are equal from a user's perspective. 
The equality * compares only non-deleted nodes. @@ -341,13 +372,16 @@ class RegTree : public Model { * \param right_leaf_weight The right leaf weight for prediction, modified by learning rate. * \param loss_change The loss change. * \param sum_hess The sum hess. - * \param leaf_right_child The right child index of leaf, by default kInvalidNodeId, - * some updaters use the right child index of leaf as a marker + * \param left_sum The sum hess of left leaf. + * \param right_sum The sum hess of right leaf. + * \param leaf_right_child The right child index of leaf, by default kInvalidNodeId, + * some updaters use the right child index of leaf as a marker */ void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, - bst_float loss_change, float sum_hess, + bst_float loss_change, float sum_hess, float left_sum, + float right_sum, bst_node_t leaf_right_child = kInvalidNodeId) { int pleft = this->AllocNode(); int pright = this->AllocNode(); @@ -363,9 +397,9 @@ class RegTree : public Model { nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child); nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child); - this->Stat(nid).loss_chg = loss_change; - this->Stat(nid).base_weight = base_weight; - this->Stat(nid).sum_hess = sum_hess; + this->Stat(nid) = {loss_change, sum_hess, base_weight}; + this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight}; + this->Stat(pright) = {0.0f, right_sum, right_leaf_weight}; } /*! @@ -402,6 +436,10 @@ class RegTree : public Model { return param.num_nodes - 1 - param.num_deleted; } + /* \brief Count number of leaves in tree. */ + bst_node_t GetNumLeaves() const; + bst_node_t GetNumSplitNodes() const; + /*! * \brief dense feature vector that can be taken by RegTree * and can be construct from sparse feature vector. 
diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 61717fc43..e8046d109 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -607,6 +607,8 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot") return new GraphvizGenerator(fmap, attrs, with_stats); }); +constexpr bst_node_t RegTree::kRoot; + std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const { @@ -623,26 +625,40 @@ bool RegTree::Equal(const RegTree& b) const { if (NumExtraNodes() != b.NumExtraNodes()) { return false; } - - std::stack<bst_node_t> nodes; - nodes.push(0); - auto& self = *this; - while (!nodes.empty()) { - auto nid = nodes.top(); - nodes.pop(); - if (!(self.nodes_.at(nid) == b.nodes_.at(nid))) { + auto const& self = *this; + bool ret { true }; + this->WalkTree([&self, &b, &ret](bst_node_t nidx) { + if (!(self.nodes_.at(nidx) == b.nodes_.at(nidx))) { + ret = false; return false; } - auto left = self[nid].LeftChild(); - auto right = self[nid].RightChild(); - if (left != RegTree::kInvalidNodeId) { - nodes.push(left); - } - if (right != RegTree::kInvalidNodeId) { - nodes.push(right); - } - } - return true; + return true; + }); + return ret; +} + +bst_node_t RegTree::GetNumLeaves() const { + bst_node_t leaves { 0 }; + auto const& self = *this; + this->WalkTree([&leaves, &self](bst_node_t nidx) { + if (self[nidx].IsLeaf()) { + leaves++; + } + return true; + }); + return leaves; +} + +bst_node_t RegTree::GetNumSplitNodes() const { + bst_node_t splits { 0 }; + auto const& self = *this; + this->WalkTree([&splits, &self](bst_node_t nidx) { + if (!self[nidx].IsLeaf()) { + splits++; + } + return true; + }); + return splits; } void RegTree::Load(dmlc::Stream* fi) { diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 7aa738461..690a8bcce 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -499,7 +499,9 @@ class ColMaker: public TreeUpdater { p_tree->ExpandNode(nid, e.best.SplitIndex(), 
e.best.split_value, e.best.DefaultLeft(), e.weight, left_leaf_weight, right_leaf_weight, e.best.loss_chg, - e.stats.sum_hess, 0); + e.stats.sum_hess, + e.best.left_sum.GetHess(), e.best.right_sum.GetHess(), + 0); } else { (*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate); } diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 65085bb70..2f78730a4 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -814,7 +814,8 @@ struct GPUHistMakerDevice { tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue, candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight, - candidate.split.loss_chg, parent_sum.sum_hess); + candidate.split.loss_chg, parent_sum.sum_hess, + left_stats.GetHess(), right_stats.GetHess()); // Set up child constraints node_value_constraints.resize(tree.GetNodes().size()); node_value_constraints[candidate.nid].SetChild( diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 725634b9e..dd57354f9 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -249,7 +249,8 @@ class HistMaker: public BaseMaker { p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value, best.DefaultLeft(), base_weight, left_leaf_weight, right_leaf_weight, best.loss_chg, - node_sum.sum_hess); + node_sum.sum_hess, + best.left_sum.GetHess(), best.right_sum.GetHess()); GradStats right_sum; right_sum.SetSubstract(node_sum, left_sum[wid]); auto left_child = (*p_tree)[nid].LeftChild(); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 3e077eed3..d9003841c 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -263,7 +263,8 @@ void QuantileHistMaker::Builder::AddSplitsToTree( spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate; p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft(), e.weight, left_leaf_weight, - 
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess); + right_leaf_weight, e.best.loss_chg, e.stats.sum_hess, + e.best.left_sum.GetHess(), e.best.right_sum.GetHess()); int left_id = (*p_tree)[nid].LeftChild(); int right_id = (*p_tree)[nid].RightChild(); @@ -410,7 +411,8 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide( spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate; p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft(), e.weight, left_leaf_weight, - right_leaf_weight, e.best.loss_chg, e.stats.sum_hess); + right_leaf_weight, e.best.loss_chg, e.stats.sum_hess, + e.best.left_sum.GetHess(), e.best.right_sum.GetHess()); this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree); diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc index 9d57f50e0..69cb4e58b 100644 --- a/src/tree/updater_skmaker.cc +++ b/src/tree/updater_skmaker.cc @@ -289,7 +289,8 @@ class SketchMaker: public BaseMaker { p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value, best.DefaultLeft(), base_weight, left_leaf_weight, right_leaf_weight, best.loss_chg, - node_stats_[nid].sum_hess); + node_stats_[nid].sum_hess, + best.left_sum.GetHess(), best.right_sum.GetHess()); } else { (*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate); } diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc index e066db989..dbe910a8f 100644 --- a/tests/cpp/tree/test_prune.cc +++ b/tests/cpp/tree/test_prune.cc @@ -42,13 +42,15 @@ TEST(Updater, Prune) { pruner->Configure(cfg); // loss_chg < min_split_loss; - tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f); + tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f, + /*left_sum=*/0.0f, /*right_sum=*/0.0f); pruner->Update(&gpair, p_dmat.get(), trees); ASSERT_EQ(tree.NumExtraNodes(), 0); // loss_chg > min_split_loss; - tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f); + tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 
0.4f, 11.0f, 0.0f, + /*left_sum=*/0.0f, /*right_sum=*/0.0f); pruner->Update(&gpair, p_dmat.get(), trees); ASSERT_EQ(tree.NumExtraNodes(), 2); @@ -63,10 +65,12 @@ TEST(Updater, Prune) { // loss_chg > min_split_loss tree.ExpandNode(tree[0].LeftChild(), 0, 0.5f, true, 0.3, 0.4, 0.5, - /*loss_chg=*/18.0f, 0.0f); + /*loss_chg=*/18.0f, 0.0f, + /*left_sum=*/0.0f, /*right_sum=*/0.0f); tree.ExpandNode(tree[0].RightChild(), 0, 0.5f, true, 0.3, 0.4, 0.5, - /*loss_chg=*/19.0f, 0.0f); + /*loss_chg=*/19.0f, 0.0f, + /*left_sum=*/0.0f, /*right_sum=*/0.0f); cfg.emplace_back(std::make_pair("max_depth", "1")); pruner->Configure(cfg); pruner->Update(&gpair, p_dmat.get(), trees); @@ -75,7 +79,8 @@ TEST(Updater, Prune) { tree.ExpandNode(tree[0].LeftChild(), 0, 0.5f, true, 0.3, 0.4, 0.5, - /*loss_chg=*/18.0f, 0.0f); + /*loss_chg=*/18.0f, 0.0f, + /*left_sum=*/0.0f, /*right_sum=*/0.0f); cfg.emplace_back(std::make_pair("min_split_loss", "0")); pruner->Configure(cfg); pruner->Update(&gpair, p_dmat.get(), trees); diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc index e8643ce1e..3689940fd 100644 --- a/tests/cpp/tree/test_refresh.cc +++ b/tests/cpp/tree/test_refresh.cc @@ -34,7 +34,8 @@ TEST(Updater, Refresh) { std::vector<RegTree*> trees {&tree}; std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam)); - tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f); + tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f, + /*left_sum=*/0.0f, /*right_sum=*/0.0f); int cleft = tree[0].LeftChild(); int cright = tree[0].RightChild(); diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 0c7d87e6c..406e9e62f 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -88,13 +88,13 @@ TEST(Tree, Load) { TEST(Tree, AllocateNode) { RegTree tree; - tree.ExpandNode( - 0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 
/*left_sum=*/0.0f, /*right_sum=*/0.0f); tree.CollapseToLeaf(0, 0); ASSERT_EQ(tree.NumExtraNodes(), 0); - tree.ExpandNode( - 0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + /*left_sum=*/0.0f, /*right_sum=*/0.0f); ASSERT_EQ(tree.NumExtraNodes(), 2); auto& nodes = tree.GetNodes(); @@ -107,18 +107,18 @@ RegTree ConstructTree() { RegTree tree; tree.ExpandNode( /*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f, - /*default_left=*/true, - 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f, + /*right_sum=*/0.0f); auto left = tree[0].LeftChild(); auto right = tree[0].RightChild(); tree.ExpandNode( /*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f, - /*default_left=*/false, - 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + /*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f, + /*right_sum=*/0.0f); tree.ExpandNode( /*nid=*/right, /*split_index=*/2, /*split_value=*/2.0f, - /*default_left=*/false, - 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + /*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f, + /*right_sum=*/0.0f); return tree; } @@ -222,7 +222,8 @@ TEST(Tree, DumpDot) { TEST(Tree, JsonIO) { RegTree tree; - tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + /*left_sum=*/0.0f, /*right_sum=*/0.0f); Json j_tree{Object()}; tree.SaveModel(&j_tree); @@ -246,8 +247,10 @@ TEST(Tree, JsonIO) { auto left = tree[0].LeftChild(); auto right = tree[0].RightChild(); - tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); - tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + /*left_sum=*/0.0f, /*right_sum=*/0.0f); + tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + /*left_sum=*/0.0f, /*right_sum=*/0.0f); tree.SaveModel(&j_tree); 
tree.ChangeToLeaf(1, 1.0f); diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc new file mode 100644 index 000000000..d3497c2b2 --- /dev/null +++ b/tests/cpp/tree/test_tree_stat.cc @@ -0,0 +1,59 @@ +#include <gtest/gtest.h> +#include <xgboost/tree_model.h> +#include <xgboost/tree_updater.h> + +#include "../helpers.h" + +namespace xgboost { +class UpdaterTreeStatTest : public ::testing::Test { + protected: + std::shared_ptr<DMatrix> p_dmat_; + HostDeviceVector<GradientPair> gpairs_; + size_t constexpr static kRows = 10; + size_t constexpr static kCols = 10; + + protected: + void SetUp() override { + p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatix(true); + auto g = GenerateRandomGradients(kRows); + gpairs_.Resize(kRows); + gpairs_.Copy(g); + } + + void RunTest(std::string updater) { + auto tparam = CreateEmptyGenericParam(0); + auto up = std::unique_ptr<TreeUpdater>{ + TreeUpdater::Create(updater, &tparam)}; + up->Configure(Args{}); + RegTree tree; + tree.param.num_feature = kCols; + up->Update(&gpairs_, p_dmat_.get(), {&tree}); + + tree.WalkTree([&tree](bst_node_t nidx) { + if (tree[nidx].IsLeaf()) { + // 1.0 is the default `min_child_weight`. + CHECK_GE(tree.Stat(nidx).sum_hess, 1.0); + } + return true; + }); + } +}; + +#if defined(XGBOOST_USE_CUDA) +TEST_F(UpdaterTreeStatTest, GPUHist) { + this->RunTest("grow_gpu_hist"); +} +#endif // defined(XGBOOST_USE_CUDA) + +TEST_F(UpdaterTreeStatTest, Hist) { + this->RunTest("grow_quantile_histmaker"); +} + +TEST_F(UpdaterTreeStatTest, Exact) { + this->RunTest("grow_colmaker"); +} + +TEST_F(UpdaterTreeStatTest, Approx) { + this->RunTest("grow_histmaker"); +} +} // namespace xgboost