From 3d81c48d3fa9a2a2ceb0d7388d3ac1436af7d8ee Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 13 Dec 2018 10:28:38 +1300 Subject: [PATCH] Remove leaf vector, add tree serialisation test, fix Windows tests (#3989) --- include/xgboost/tree_model.h | 18 ------ src/tree/param.h | 2 - src/tree/updater_colmaker.cc | 1 - src/tree/updater_histmaker.cc | 1 - src/tree/updater_quantile_hist.cc | 1 - src/tree/updater_refresh.cc | 1 - src/tree/updater_skmaker.cc | 4 -- tests/cpp/tree/test_quantile_hist.cc | 16 ++--- tests/cpp/tree/test_tree_model.cc | 87 ++++++++++++++++++++++++++++ 9 files changed, 95 insertions(+), 36 deletions(-) create mode 100644 tests/cpp/tree/test_tree_model.cc diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 8729b74d3..554c7cd92 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -198,8 +198,6 @@ class TreeModel { std::vector deleted_nodes_; // stats of nodes std::vector stats_; - // leaf vector, that is used to store additional information - std::vector leaf_vector_; // allocate a new node, // !!!!!! NOTE: may cause BUG here, nodes.resize inline int AllocNode() { @@ -214,7 +212,6 @@ class TreeModel { << "number of nodes in the tree exceed 2^31"; nodes_.resize(param.num_nodes); stats_.resize(param.num_nodes); - leaf_vector_.resize(param.num_nodes * param.size_leaf_vector); return nd; } // delete a tree node, keep the parent field to allow trace back @@ -284,22 +281,11 @@ class TreeModel { inline const NodeStat& Stat(int nid) const { return stats_[nid]; } - /*! \brief get leaf vector given nid */ - inline bst_float* Leafvec(int nid) { - if (leaf_vector_.size() == 0) return nullptr; - return& leaf_vector_[nid * param.size_leaf_vector]; - } - /*! \brief get leaf vector given nid */ - inline const bst_float* Leafvec(int nid) const { - if (leaf_vector_.size() == 0) return nullptr; - return& leaf_vector_[nid * param.size_leaf_vector]; - } /*! \brief initialize the model */ inline void InitModel() { param.num_nodes = param.num_roots; nodes_.resize(param.num_nodes); stats_.resize(param.num_nodes); - leaf_vector_.resize(param.num_nodes * param.size_leaf_vector, 0.0f); for (int i = 0; i < param.num_nodes; i ++) { nodes_[i].SetLeaf(0.0f); nodes_[i].SetParent(-1); @@ -318,9 +304,6 @@ class TreeModel { sizeof(Node) * nodes_.size()); CHECK_EQ(fi->Read(dmlc::BeginPtr(stats_), sizeof(NodeStat) * stats_.size()), sizeof(NodeStat) * stats_.size()); - if (param.size_leaf_vector != 0) { - CHECK(fi->Read(&leaf_vector_)); - } // chg deleted nodes deleted_nodes_.resize(0); for (int i = param.num_roots; i < param.num_nodes; ++i) { @@ -339,7 +322,6 @@ class TreeModel { CHECK_NE(param.num_nodes, 0); fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()); fo->Write(dmlc::BeginPtr(stats_), sizeof(NodeStat) * nodes_.size()); - if (param.size_leaf_vector != 0) fo->Write(leaf_vector_); } /*! * \brief add child nodes to node diff --git a/src/tree/param.h b/src/tree/param.h index 8f607647e..edc3fbaff 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -408,8 +408,6 @@ template } /*! \return whether the statistics is not used yet */ inline bool Empty() const { return sum_hess == 0.0; } - /*! \brief set leaf vector value based on statistics */ - inline void SetLeafVec(const TrainParam& param, bst_float* vec) const {} // constructor to allow inheritance GradStats() = default; /*! \brief add statistics to the data */ diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 0ab671f55..abfd0eb51 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -130,7 +130,6 @@ class ColMaker: public TreeUpdater { p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg; p_tree->Stat(nid).base_weight = snode_[nid].weight; p_tree->Stat(nid).sum_hess = static_cast(snode_[nid].stats.sum_hess); - snode_[nid].stats.SetLeafVec(param_, p_tree->Leafvec(nid)); } } diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 729324b81..936ccd498 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -263,7 +263,6 @@ class HistMaker: public BaseMaker { inline void SetStats(RegTree *p_tree, int nid, const TStats &node_sum) { p_tree->Stat(nid).base_weight = static_cast(node_sum.CalcWeight(param_)); p_tree->Stat(nid).sum_hess = static_cast(node_sum.sum_hess); - node_sum.SetLeafVec(param_, p_tree->Leafvec(nid)); } }; diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 6a5dde326..a86555cd7 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -203,7 +203,6 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat, p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg; p_tree->Stat(nid).base_weight = snode_[nid].weight; p_tree->Stat(nid).sum_hess = static_cast(snode_[nid].stats.sum_hess); - snode_[nid].stats.SetLeafVec(param_, p_tree->Leafvec(nid)); } pruner_->Update(gpair, p_fmat, std::vector{p_tree}); diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 92ae5be30..a9ca17415 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -127,7 +127,6 @@ class TreeRefresher: public TreeUpdater { RegTree &tree = *p_tree; tree.Stat(nid).base_weight = static_cast(gstats[nid].CalcWeight(param_)); tree.Stat(nid).sum_hess = static_cast(gstats[nid].sum_hess); - gstats[nid].SetLeafVec(param_, tree.Leafvec(nid)); if (tree[nid].IsLeaf()) { if (param_.refresh_leaf) { tree[nid].SetLeaf(tree.Stat(nid).base_weight * param_.learning_rate); diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc index 05eea7d14..405d1c2bf 100644 --- a/src/tree/updater_skmaker.cc +++ b/src/tree/updater_skmaker.cc @@ -128,9 +128,6 @@ class SketchMaker: public BaseMaker { inline static void Reduce(SKStats &a, const SKStats &b) { // NOLINT(*) a.Add(b); } - /*! \brief set leaf vector value based on statistics */ - inline void SetLeafVec(const TrainParam ¶m, bst_float *vec) const { - } }; inline void BuildSketch(const std::vector &gpair, DMatrix *p_fmat, @@ -303,7 +300,6 @@ class SketchMaker: public BaseMaker { inline void SetStats(int nid, const SKStats &node_sum, RegTree *p_tree) { p_tree->Stat(nid).base_weight = static_cast(node_sum.CalcWeight(param_)); p_tree->Stat(nid).sum_hess = static_cast(node_sum.sum_hess); - node_sum.SetLeafVec(param_, p_tree->Leafvec(nid)); } inline void EnumerateSplit(const WXQSketch::Summary &pos_grad, const WXQSketch::Summary &neg_grad, diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index b4336b857..5ac8575a1 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -48,14 +48,14 @@ class QuantileHistMock : public QuantileHistMaker { BuildHist(gpair, row_set_collection_[nid], gmat, quantile_index_block, hist_[nid]); std::vector solution { - {0.27, 0.29}, {0.27, 0.29}, {0.47, 0.49}, - {0.27, 0.29}, {0.57, 0.59}, {0.26, 0.27}, - {0.37, 0.39}, {0.23, 0.24}, {0.37, 0.39}, - {0.27, 0.28}, {0.27, 0.29}, {0.37, 0.39}, - {0.26, 0.27}, {0.23, 0.24}, {0.57, 0.59}, - {0.47, 0.49}, {0.47, 0.49}, {0.37, 0.39}, - {0.26, 0.27}, {0.23, 0.24}, {0.27, 0.28}, - {0.57, 0.59}, {0.23, 0.24}, {0.47, 0.49}}; + {0.27f, 0.29f}, {0.27f, 0.29f}, {0.47f, 0.49f}, + {0.27f, 0.29f}, {0.57f, 0.59f}, {0.26f, 0.27f}, + {0.37f, 0.39f}, {0.23f, 0.24f}, {0.37f, 0.39f}, + {0.27f, 0.28f}, {0.27f, 0.29f}, {0.37f, 0.39f}, + {0.26f, 0.27f}, {0.23f, 0.24f}, {0.57f, 0.59f}, + {0.47f, 0.49f}, {0.47f, 0.49f}, {0.37f, 0.39f}, + {0.26f, 0.27f}, {0.23f, 0.24f}, {0.27f, 0.28f}, + {0.57f, 0.59f}, {0.23f, 0.24f}, {0.47f, 0.49f}}; for (size_t i = 0; i < hist_[nid].size; ++i) { GradientPairPrecise sol = solution[i]; diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc new file mode 100644 index 000000000..21c2eeeda --- /dev/null +++ b/tests/cpp/tree/test_tree_model.cc @@ -0,0 +1,87 @@ +// Copyright by Contributors +#include +#include +#include "../helpers.h" +#include "dmlc/filesystem.h" + +namespace xgboost { +// Manually construct tree in binary format +// Do not use structs in case they change +// We want to preserve backwards compatibility +TEST(Tree, Load) { + dmlc::TemporaryDirectory tempdir; + const std::string tmp_file = tempdir.path + "/tree.model"; + std::unique_ptr fo(dmlc::Stream::Create(tmp_file.c_str(), "w")); + + // Write params + EXPECT_EQ(sizeof(TreeParam), (31 + 6) * sizeof(int)); + int num_roots = 1; + int num_nodes = 2; + int num_deleted = 0; + int max_depth = 1; + int num_feature = 0; + int size_leaf_vector = 0; + int reserved[31]; + fo->Write(&num_roots, sizeof(int)); + fo->Write(&num_nodes, sizeof(int)); + fo->Write(&num_deleted, sizeof(int)); + fo->Write(&max_depth, sizeof(int)); + fo->Write(&num_feature, sizeof(int)); + fo->Write(&size_leaf_vector, sizeof(int)); + fo->Write(reserved, sizeof(int) * 31); + + // Write 2 nodes + EXPECT_EQ(sizeof(RegTree::Node), + 3 * sizeof(int) + 1 * sizeof(unsigned) + sizeof(float)); + int parent = -1; + int cleft = 1; + int cright = -1; + unsigned sindex = 5; + float split_or_weight = 0.5; + fo->Write(&parent, sizeof(int)); + fo->Write(&cleft, sizeof(int)); + fo->Write(&cright, sizeof(int)); + fo->Write(&sindex, sizeof(unsigned)); + fo->Write(&split_or_weight, sizeof(float)); + parent = 0; + cleft = -1; + cright = -1; + sindex = 2; + split_or_weight = 0.1; + fo->Write(&parent, sizeof(int)); + fo->Write(&cleft, sizeof(int)); + fo->Write(&cright, sizeof(int)); + fo->Write(&sindex, sizeof(unsigned)); + fo->Write(&split_or_weight, sizeof(float)); + + // Write 2x node stats + EXPECT_EQ(sizeof(RTreeNodeStat), 3 * sizeof(float) + sizeof(int)); + bst_float loss_chg = 5.0; + bst_float sum_hess = 1.0; + bst_float base_weight = 3.0; + int leaf_child_cnt = 0; + fo->Write(&loss_chg, sizeof(float)); + fo->Write(&sum_hess, sizeof(float)); + fo->Write(&base_weight, sizeof(float)); + fo->Write(&leaf_child_cnt, sizeof(int)); + + loss_chg = 50.0; + sum_hess = 10.0; + base_weight = 30.0; + leaf_child_cnt = 0; + fo->Write(&loss_chg, sizeof(float)); + fo->Write(&sum_hess, sizeof(float)); + fo->Write(&base_weight, sizeof(float)); + fo->Write(&leaf_child_cnt, sizeof(int)); + fo.reset(); + std::unique_ptr fi(dmlc::Stream::Create(tmp_file.c_str(), "r")); + + xgboost::RegTree tree; + tree.Load(fi.get()); + EXPECT_EQ(tree.GetDepth(1), 1); + EXPECT_EQ(tree[0].SplitCond(), 0.5f); + EXPECT_EQ(tree[0].SplitIndex(), 5); + EXPECT_EQ(tree[1].LeafValue(), 0.1f); + EXPECT_TRUE(tree[1].IsLeaf()); +} +} // namespace xgboost