Check whether current updater can modify a tree. (#5406)
* Check whether current updater can modify a tree. * Fix tree model JSON IO for pruned trees.
This commit is contained in:
parent
b745b7acce
commit
ab7a46a1a4
@ -65,6 +65,7 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
|
||||
DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
|
||||
DMLC_DECLARE_FIELD(num_feature)
|
||||
.describe("Number of features used in tree construction.");
|
||||
DMLC_DECLARE_FIELD(num_deleted);
|
||||
DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
|
||||
.describe("Size of leaf vector, reserved for vector tree");
|
||||
}
|
||||
@ -114,6 +115,7 @@ class RegTree : public Model {
|
||||
Node(int32_t cleft, int32_t cright, int32_t parent,
|
||||
uint32_t split_ind, float split_cond, bool default_left) :
|
||||
parent_{parent}, cleft_{cleft}, cright_{cright} {
|
||||
this->SetParent(parent_);
|
||||
this->SetSplit(split_ind, split_cond, default_left);
|
||||
}
|
||||
|
||||
@ -319,6 +321,13 @@ class RegTree : public Model {
|
||||
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
|
||||
deleted_nodes_ == b.deleted_nodes_ && param == b.param;
|
||||
}
|
||||
/*!
|
||||
* \brief Compares whether 2 trees are equal from a user's perspective. The equality
|
||||
* compares only non-deleted nodes.
|
||||
*
|
||||
* \parm b The other tree.
|
||||
*/
|
||||
bool Equal(const RegTree& b) const;
|
||||
|
||||
/**
|
||||
* \brief Expands a leaf node into two additional leaf nodes.
|
||||
|
||||
@ -40,6 +40,13 @@ class TreeUpdater : public Configurable {
|
||||
* \param args arguments to the objective function.
|
||||
*/
|
||||
virtual void Configure(const Args& args) = 0;
|
||||
/*! \brief Whether this updater can be used for updating existing trees.
|
||||
*
|
||||
* Some updaters are used for building new trees (like `hist`), while some others are
|
||||
* used for modifying existing trees (like `prune`). Return true if it can modify
|
||||
* existing trees.
|
||||
*/
|
||||
virtual bool CanModifyTree() const { return false; }
|
||||
/*!
|
||||
* \brief perform update to the tree models
|
||||
* \param gpair the gradient pair statistics of the data
|
||||
|
||||
@ -273,6 +273,12 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair,
|
||||
new_trees.push_back(ptr.get());
|
||||
ret->push_back(std::move(ptr));
|
||||
} else if (tparam_.process_type == TreeProcessType::kUpdate) {
|
||||
for (auto const& up : updaters_) {
|
||||
CHECK(up->CanModifyTree())
|
||||
<< "Updater: `" << up->Name() << "` "
|
||||
<< "can not be used to modify existing trees. "
|
||||
<< "Set `process_type` to `default` if you want to build new trees.";
|
||||
}
|
||||
CHECK_LT(model_.trees.size(), model_.trees_to_update.size());
|
||||
// move an existing tree from trees_to_update
|
||||
auto t = std::move(model_.trees_to_update[model_.trees.size() +
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include <limits>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <stack>
|
||||
|
||||
#include "param.h"
|
||||
#include "../common/common.h"
|
||||
@ -618,6 +619,32 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||
return result;
|
||||
}
|
||||
|
||||
bool RegTree::Equal(const RegTree& b) const {
|
||||
if (NumExtraNodes() != b.NumExtraNodes()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::stack<bst_node_t> nodes;
|
||||
nodes.push(0);
|
||||
auto& self = *this;
|
||||
while (!nodes.empty()) {
|
||||
auto nid = nodes.top();
|
||||
nodes.pop();
|
||||
if (!(self.nodes_.at(nid) == b.nodes_.at(nid))) {
|
||||
return false;
|
||||
}
|
||||
auto left = self[nid].LeftChild();
|
||||
auto right = self[nid].RightChild();
|
||||
if (left != RegTree::kInvalidNodeId) {
|
||||
nodes.push(left);
|
||||
}
|
||||
if (right != RegTree::kInvalidNodeId) {
|
||||
nodes.push(right);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void RegTree::Load(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam));
|
||||
nodes_.resize(param.num_nodes);
|
||||
@ -673,6 +700,9 @@ void RegTree::LoadModel(Json const& in) {
|
||||
auto const& default_left = get<Array const>(in["default_left"]);
|
||||
CHECK_EQ(default_left.size(), n_nodes);
|
||||
|
||||
stats_.clear();
|
||||
nodes_.clear();
|
||||
|
||||
stats_.resize(n_nodes);
|
||||
nodes_.resize(n_nodes);
|
||||
for (int32_t i = 0; i < n_nodes; ++i) {
|
||||
@ -692,13 +722,19 @@ void RegTree::LoadModel(Json const& in) {
|
||||
n = Node{left, right, parent, ind, cond, dft_left};
|
||||
}
|
||||
|
||||
|
||||
deleted_nodes_.resize(0);
|
||||
deleted_nodes_.clear();
|
||||
for (bst_node_t i = 1; i < param.num_nodes; ++i) {
|
||||
if (nodes_[i].IsDeleted()) {
|
||||
deleted_nodes_.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
auto& self = *this;
|
||||
for (auto nid = 1; nid < n_nodes; ++nid) {
|
||||
auto parent = self[nid].Parent();
|
||||
CHECK_NE(parent, RegTree::kInvalidNodeId);
|
||||
self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid);
|
||||
}
|
||||
CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param.num_deleted);
|
||||
}
|
||||
|
||||
|
||||
@ -44,6 +44,9 @@ class TreePruner: public TreeUpdater {
|
||||
auto& out = *p_out;
|
||||
out["train_param"] = toJson(param_);
|
||||
}
|
||||
bool CanModifyTree() const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
// update the tree, do pruning
|
||||
void Update(HostDeviceVector<GradientPair> *gpair,
|
||||
|
||||
@ -36,6 +36,9 @@ class TreeRefresher: public TreeUpdater {
|
||||
char const* Name() const override {
|
||||
return "refresh";
|
||||
}
|
||||
bool CanModifyTree() const override {
|
||||
return true;
|
||||
}
|
||||
// update the tree, do pruning
|
||||
void Update(HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix *p_fmat,
|
||||
|
||||
@ -51,6 +51,22 @@ TEST(GBTree, SelectTreeMethod) {
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
}
|
||||
|
||||
TEST(GBTree, WrongUpdater) {
|
||||
size_t constexpr kRows = 17;
|
||||
size_t constexpr kCols = 15;
|
||||
|
||||
auto pp_dmat = CreateDMatrix(kRows, kCols, 0);
|
||||
std::shared_ptr<DMatrix> p_dmat {*pp_dmat};
|
||||
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
|
||||
// Hist can not be used for updating tree.
|
||||
learner->SetParams(Args{{"tree_method", "hist"}, {"process_type", "update"}});
|
||||
ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
|
||||
delete pp_dmat;
|
||||
}
|
||||
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
TEST(GBTree, ChoosePredictor) {
|
||||
size_t constexpr kRows = 17;
|
||||
|
||||
@ -225,8 +225,6 @@ TEST(Tree, JsonIO) {
|
||||
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
Json j_tree{Object()};
|
||||
tree.SaveModel(&j_tree);
|
||||
std::stringstream ss;
|
||||
Json::Dump(j_tree, &ss);
|
||||
|
||||
auto tparam = j_tree["tree_param"];
|
||||
ASSERT_EQ(get<String>(tparam["num_feature"]), "0");
|
||||
@ -243,6 +241,23 @@ TEST(Tree, JsonIO) {
|
||||
RegTree loaded_tree;
|
||||
loaded_tree.LoadModel(j_tree);
|
||||
ASSERT_EQ(loaded_tree.param.num_nodes, 3);
|
||||
|
||||
ASSERT_TRUE(loaded_tree == tree);
|
||||
|
||||
auto left = tree[0].LeftChild();
|
||||
auto right = tree[0].RightChild();
|
||||
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
tree.SaveModel(&j_tree);
|
||||
|
||||
tree.ChangeToLeaf(1, 1.0f);
|
||||
ASSERT_EQ(tree[1].LeftChild(), -1);
|
||||
ASSERT_EQ(tree[1].RightChild(), -1);
|
||||
tree.SaveModel(&j_tree);
|
||||
loaded_tree.LoadModel(j_tree);
|
||||
ASSERT_EQ(loaded_tree[1].LeftChild(), -1);
|
||||
ASSERT_EQ(loaded_tree[1].RightChild(), -1);
|
||||
ASSERT_TRUE(tree.Equal(loaded_tree));
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user