Check whether current updater can modify a tree. (#5406)

* Check whether current updater can modify a tree.

* Fix tree model JSON IO for pruned trees.
This commit is contained in:
Jiaming Yuan 2020-03-14 09:24:08 +08:00 committed by GitHub
parent b745b7acce
commit ab7a46a1a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 99 additions and 4 deletions

View File

@ -65,6 +65,7 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1); DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
DMLC_DECLARE_FIELD(num_feature) DMLC_DECLARE_FIELD(num_feature)
.describe("Number of features used in tree construction."); .describe("Number of features used in tree construction.");
DMLC_DECLARE_FIELD(num_deleted);
DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0) DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
.describe("Size of leaf vector, reserved for vector tree"); .describe("Size of leaf vector, reserved for vector tree");
} }
@ -114,6 +115,7 @@ class RegTree : public Model {
Node(int32_t cleft, int32_t cright, int32_t parent, Node(int32_t cleft, int32_t cright, int32_t parent,
uint32_t split_ind, float split_cond, bool default_left) : uint32_t split_ind, float split_cond, bool default_left) :
parent_{parent}, cleft_{cleft}, cright_{cright} { parent_{parent}, cleft_{cleft}, cright_{cright} {
this->SetParent(parent_);
this->SetSplit(split_ind, split_cond, default_left); this->SetSplit(split_ind, split_cond, default_left);
} }
@ -319,6 +321,13 @@ class RegTree : public Model {
return nodes_ == b.nodes_ && stats_ == b.stats_ && return nodes_ == b.nodes_ && stats_ == b.stats_ &&
deleted_nodes_ == b.deleted_nodes_ && param == b.param; deleted_nodes_ == b.deleted_nodes_ && param == b.param;
} }
/*!
* \brief Compares whether 2 trees are equal from a user's perspective. The equality
* compares only non-deleted nodes.
*
* \parm b The other tree.
*/
bool Equal(const RegTree& b) const;
/** /**
* \brief Expands a leaf node into two additional leaf nodes. * \brief Expands a leaf node into two additional leaf nodes.

View File

@ -40,6 +40,13 @@ class TreeUpdater : public Configurable {
* \param args arguments to the objective function. * \param args arguments to the objective function.
*/ */
virtual void Configure(const Args& args) = 0; virtual void Configure(const Args& args) = 0;
/*! \brief Whether this updater can be used for updating existing trees.
*
* Some updaters are used for building new trees (like `hist`), while some others are
* used for modifying existing trees (like `prune`). Return true if it can modify
* existing trees.
*/
virtual bool CanModifyTree() const { return false; }
/*! /*!
* \brief perform update to the tree models * \brief perform update to the tree models
* \param gpair the gradient pair statistics of the data * \param gpair the gradient pair statistics of the data

View File

@ -273,6 +273,12 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair,
new_trees.push_back(ptr.get()); new_trees.push_back(ptr.get());
ret->push_back(std::move(ptr)); ret->push_back(std::move(ptr));
} else if (tparam_.process_type == TreeProcessType::kUpdate) { } else if (tparam_.process_type == TreeProcessType::kUpdate) {
for (auto const& up : updaters_) {
CHECK(up->CanModifyTree())
<< "Updater: `" << up->Name() << "` "
<< "can not be used to modify existing trees. "
<< "Set `process_type` to `default` if you want to build new trees.";
}
CHECK_LT(model_.trees.size(), model_.trees_to_update.size()); CHECK_LT(model_.trees.size(), model_.trees_to_update.size());
// move an existing tree from trees_to_update // move an existing tree from trees_to_update
auto t = std::move(model_.trees_to_update[model_.trees.size() + auto t = std::move(model_.trees_to_update[model_.trees.size() +

View File

@ -14,6 +14,7 @@
#include <limits> #include <limits>
#include <cmath> #include <cmath>
#include <iomanip> #include <iomanip>
#include <stack>
#include "param.h" #include "param.h"
#include "../common/common.h" #include "../common/common.h"
@ -618,6 +619,32 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
return result; return result;
} }
bool RegTree::Equal(const RegTree& b) const {
if (NumExtraNodes() != b.NumExtraNodes()) {
return false;
}
std::stack<bst_node_t> nodes;
nodes.push(0);
auto& self = *this;
while (!nodes.empty()) {
auto nid = nodes.top();
nodes.pop();
if (!(self.nodes_.at(nid) == b.nodes_.at(nid))) {
return false;
}
auto left = self[nid].LeftChild();
auto right = self[nid].RightChild();
if (left != RegTree::kInvalidNodeId) {
nodes.push(left);
}
if (right != RegTree::kInvalidNodeId) {
nodes.push(right);
}
}
return true;
}
void RegTree::Load(dmlc::Stream* fi) { void RegTree::Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam)); CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
nodes_.resize(param.num_nodes); nodes_.resize(param.num_nodes);
@ -673,6 +700,9 @@ void RegTree::LoadModel(Json const& in) {
auto const& default_left = get<Array const>(in["default_left"]); auto const& default_left = get<Array const>(in["default_left"]);
CHECK_EQ(default_left.size(), n_nodes); CHECK_EQ(default_left.size(), n_nodes);
stats_.clear();
nodes_.clear();
stats_.resize(n_nodes); stats_.resize(n_nodes);
nodes_.resize(n_nodes); nodes_.resize(n_nodes);
for (int32_t i = 0; i < n_nodes; ++i) { for (int32_t i = 0; i < n_nodes; ++i) {
@ -692,13 +722,19 @@ void RegTree::LoadModel(Json const& in) {
n = Node{left, right, parent, ind, cond, dft_left}; n = Node{left, right, parent, ind, cond, dft_left};
} }
deleted_nodes_.clear();
deleted_nodes_.resize(0);
for (bst_node_t i = 1; i < param.num_nodes; ++i) { for (bst_node_t i = 1; i < param.num_nodes; ++i) {
if (nodes_[i].IsDeleted()) { if (nodes_[i].IsDeleted()) {
deleted_nodes_.push_back(i); deleted_nodes_.push_back(i);
} }
} }
auto& self = *this;
for (auto nid = 1; nid < n_nodes; ++nid) {
auto parent = self[nid].Parent();
CHECK_NE(parent, RegTree::kInvalidNodeId);
self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid);
}
CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param.num_deleted); CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param.num_deleted);
} }

View File

@ -44,6 +44,9 @@ class TreePruner: public TreeUpdater {
auto& out = *p_out; auto& out = *p_out;
out["train_param"] = toJson(param_); out["train_param"] = toJson(param_);
} }
bool CanModifyTree() const override {
return true;
}
// update the tree, do pruning // update the tree, do pruning
void Update(HostDeviceVector<GradientPair> *gpair, void Update(HostDeviceVector<GradientPair> *gpair,

View File

@ -36,6 +36,9 @@ class TreeRefresher: public TreeUpdater {
char const* Name() const override { char const* Name() const override {
return "refresh"; return "refresh";
} }
bool CanModifyTree() const override {
return true;
}
// update the tree, do pruning // update the tree, do pruning
void Update(HostDeviceVector<GradientPair> *gpair, void Update(HostDeviceVector<GradientPair> *gpair,
DMatrix *p_fmat, DMatrix *p_fmat,

View File

@ -51,6 +51,22 @@ TEST(GBTree, SelectTreeMethod) {
#endif // XGBOOST_USE_CUDA #endif // XGBOOST_USE_CUDA
} }
TEST(GBTree, WrongUpdater) {
size_t constexpr kRows = 17;
size_t constexpr kCols = 15;
auto pp_dmat = CreateDMatrix(kRows, kCols, 0);
std::shared_ptr<DMatrix> p_dmat {*pp_dmat};
p_dmat->Info().labels_.Resize(kRows);
auto learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
// Hist can not be used for updating tree.
learner->SetParams(Args{{"tree_method", "hist"}, {"process_type", "update"}});
ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
delete pp_dmat;
}
#ifdef XGBOOST_USE_CUDA #ifdef XGBOOST_USE_CUDA
TEST(GBTree, ChoosePredictor) { TEST(GBTree, ChoosePredictor) {
size_t constexpr kRows = 17; size_t constexpr kRows = 17;

View File

@ -225,8 +225,6 @@ TEST(Tree, JsonIO) {
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
Json j_tree{Object()}; Json j_tree{Object()};
tree.SaveModel(&j_tree); tree.SaveModel(&j_tree);
std::stringstream ss;
Json::Dump(j_tree, &ss);
auto tparam = j_tree["tree_param"]; auto tparam = j_tree["tree_param"];
ASSERT_EQ(get<String>(tparam["num_feature"]), "0"); ASSERT_EQ(get<String>(tparam["num_feature"]), "0");
@ -243,6 +241,23 @@ TEST(Tree, JsonIO) {
RegTree loaded_tree; RegTree loaded_tree;
loaded_tree.LoadModel(j_tree); loaded_tree.LoadModel(j_tree);
ASSERT_EQ(loaded_tree.param.num_nodes, 3); ASSERT_EQ(loaded_tree.param.num_nodes, 3);
ASSERT_TRUE(loaded_tree == tree);
auto left = tree[0].LeftChild();
auto right = tree[0].RightChild();
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.SaveModel(&j_tree);
tree.ChangeToLeaf(1, 1.0f);
ASSERT_EQ(tree[1].LeftChild(), -1);
ASSERT_EQ(tree[1].RightChild(), -1);
tree.SaveModel(&j_tree);
loaded_tree.LoadModel(j_tree);
ASSERT_EQ(loaded_tree[1].LeftChild(), -1);
ASSERT_EQ(loaded_tree[1].RightChild(), -1);
ASSERT_TRUE(tree.Equal(loaded_tree));
} }
} // namespace xgboost } // namespace xgboost