diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 02ac19f9a..d7f730f5b 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -65,6 +65,7 @@ struct TreeParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1); DMLC_DECLARE_FIELD(num_feature) .describe("Number of features used in tree construction."); + DMLC_DECLARE_FIELD(num_deleted); DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0) .describe("Size of leaf vector, reserved for vector tree"); } @@ -114,6 +115,7 @@ class RegTree : public Model { Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left) : parent_{parent}, cleft_{cleft}, cright_{cright} { + this->SetParent(parent_); this->SetSplit(split_ind, split_cond, default_left); } @@ -319,6 +321,13 @@ class RegTree : public Model { return nodes_ == b.nodes_ && stats_ == b.stats_ && deleted_nodes_ == b.deleted_nodes_ && param == b.param; } + /*! + * \brief Compares whether 2 trees are equal from a user's perspective. The equality + * compares only non-deleted nodes. + * + * \param b The other tree. + */ + bool Equal(const RegTree& b) const; /** * \brief Expands a leaf node into two additional leaf nodes. diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 2508f130d..69f972279 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -40,6 +40,13 @@ class TreeUpdater : public Configurable { * \param args arguments to the objective function. */ virtual void Configure(const Args& args) = 0; + /*! \brief Whether this updater can be used for updating existing trees. + * + * Some updaters are used for building new trees (like `hist`), while some others are + * used for modifying existing trees (like `prune`). Return true if it can modify + * existing trees. + */ + virtual bool CanModifyTree() const { return false; } /*! 
* \brief perform update to the tree models * \param gpair the gradient pair statistics of the data diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index f9bcc71b0..96b395d79 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -273,6 +273,12 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, new_trees.push_back(ptr.get()); ret->push_back(std::move(ptr)); } else if (tparam_.process_type == TreeProcessType::kUpdate) { + for (auto const& up : updaters_) { + CHECK(up->CanModifyTree()) + << "Updater: `" << up->Name() << "` " + << "can not be used to modify existing trees. " + << "Set `process_type` to `default` if you want to build new trees."; + } CHECK_LT(model_.trees.size(), model_.trees_to_update.size()); // move an existing tree from trees_to_update auto t = std::move(model_.trees_to_update[model_.trees.size() + diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index f6cdc06e2..65544e9cd 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -14,6 +14,7 @@ #include #include #include +#include <stack> #include "param.h" #include "../common/common.h" @@ -618,6 +619,32 @@ std::string RegTree::DumpModel(const FeatureMap& fmap, return result; } +bool RegTree::Equal(const RegTree& b) const { + if (NumExtraNodes() != b.NumExtraNodes()) { + return false; + } + + std::stack<bst_node_t> nodes; + nodes.push(0); + auto& self = *this; + while (!nodes.empty()) { + auto nid = nodes.top(); + nodes.pop(); + if (!(self.nodes_.at(nid) == b.nodes_.at(nid))) { + return false; + } + auto left = self[nid].LeftChild(); + auto right = self[nid].RightChild(); + if (left != RegTree::kInvalidNodeId) { + nodes.push(left); + } + if (right != RegTree::kInvalidNodeId) { + nodes.push(right); + } + } + return true; +} + void RegTree::Load(dmlc::Stream* fi) { CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam)); nodes_.resize(param.num_nodes); @@ -673,6 +700,9 @@ void RegTree::LoadModel(Json const& in) { auto const& default_left = get(in["default_left"]); 
CHECK_EQ(default_left.size(), n_nodes); + stats_.clear(); + nodes_.clear(); + stats_.resize(n_nodes); nodes_.resize(n_nodes); for (int32_t i = 0; i < n_nodes; ++i) { @@ -692,13 +722,19 @@ void RegTree::LoadModel(Json const& in) { n = Node{left, right, parent, ind, cond, dft_left}; } - - deleted_nodes_.resize(0); + deleted_nodes_.clear(); for (bst_node_t i = 1; i < param.num_nodes; ++i) { if (nodes_[i].IsDeleted()) { deleted_nodes_.push_back(i); } } + + auto& self = *this; + for (auto nid = 1; nid < n_nodes; ++nid) { + auto parent = self[nid].Parent(); + CHECK_NE(parent, RegTree::kInvalidNodeId); + self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid); + } CHECK_EQ(static_cast(deleted_nodes_.size()), param.num_deleted); } diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index a621fca46..5707b809c 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -44,6 +44,9 @@ class TreePruner: public TreeUpdater { auto& out = *p_out; out["train_param"] = toJson(param_); } + bool CanModifyTree() const override { + return true; + } // update the tree, do pruning void Update(HostDeviceVector *gpair, diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index c954712d7..4c6c256e4 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -36,6 +36,9 @@ class TreeRefresher: public TreeUpdater { char const* Name() const override { return "refresh"; } + bool CanModifyTree() const override { + return true; + } // update the tree, do pruning void Update(HostDeviceVector *gpair, DMatrix *p_fmat, diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index f3655c009..33be0f818 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -51,6 +51,22 @@ TEST(GBTree, SelectTreeMethod) { #endif // XGBOOST_USE_CUDA } +TEST(GBTree, WrongUpdater) { + size_t constexpr kRows = 17; + size_t constexpr kCols = 15; + + auto pp_dmat = CreateDMatrix(kRows, kCols, 0); + 
std::shared_ptr<DMatrix> p_dmat {*pp_dmat}; + + p_dmat->Info().labels_.Resize(kRows); + + auto learner = std::unique_ptr<Learner>(Learner::Create({p_dmat})); + // Hist can not be used for updating tree. + learner->SetParams(Args{{"tree_method", "hist"}, {"process_type", "update"}}); + ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error); + delete pp_dmat; +} + #ifdef XGBOOST_USE_CUDA TEST(GBTree, ChoosePredictor) { size_t constexpr kRows = 17; diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 0edf00e80..0c7d87e6c 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -225,8 +225,6 @@ TEST(Tree, JsonIO) { tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); Json j_tree{Object()}; tree.SaveModel(&j_tree); - std::stringstream ss; - Json::Dump(j_tree, &ss); auto tparam = j_tree["tree_param"]; ASSERT_EQ(get(tparam["num_feature"]), "0"); @@ -243,6 +241,23 @@ TEST(Tree, JsonIO) { RegTree loaded_tree; loaded_tree.LoadModel(j_tree); ASSERT_EQ(loaded_tree.param.num_nodes, 3); + + ASSERT_TRUE(loaded_tree == tree); + + auto left = tree[0].LeftChild(); + auto right = tree[0].RightChild(); + tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + tree.SaveModel(&j_tree); + + tree.ChangeToLeaf(1, 1.0f); + ASSERT_EQ(tree[1].LeftChild(), -1); + ASSERT_EQ(tree[1].RightChild(), -1); + tree.SaveModel(&j_tree); + loaded_tree.LoadModel(j_tree); + ASSERT_EQ(loaded_tree[1].LeftChild(), -1); + ASSERT_EQ(loaded_tree[1].RightChild(), -1); + ASSERT_TRUE(tree.Equal(loaded_tree)); } } // namespace xgboost