Check whether current updater can modify a tree. (#5406)

* Check whether current updater can modify a tree.

* Fix tree model JSON IO for pruned trees.
This commit is contained in:
Jiaming Yuan
2020-03-14 09:24:08 +08:00
committed by GitHub
parent b745b7acce
commit ab7a46a1a4
8 changed files with 99 additions and 4 deletions

View File

@@ -14,6 +14,7 @@
#include <limits>
#include <cmath>
#include <iomanip>
#include <stack>
#include "param.h"
#include "../common/common.h"
@@ -618,6 +619,32 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
return result;
}
bool RegTree::Equal(const RegTree& b) const {
if (NumExtraNodes() != b.NumExtraNodes()) {
return false;
}
std::stack<bst_node_t> nodes;
nodes.push(0);
auto& self = *this;
while (!nodes.empty()) {
auto nid = nodes.top();
nodes.pop();
if (!(self.nodes_.at(nid) == b.nodes_.at(nid))) {
return false;
}
auto left = self[nid].LeftChild();
auto right = self[nid].RightChild();
if (left != RegTree::kInvalidNodeId) {
nodes.push(left);
}
if (right != RegTree::kInvalidNodeId) {
nodes.push(right);
}
}
return true;
}
void RegTree::Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
nodes_.resize(param.num_nodes);
@@ -673,6 +700,9 @@ void RegTree::LoadModel(Json const& in) {
auto const& default_left = get<Array const>(in["default_left"]);
CHECK_EQ(default_left.size(), n_nodes);
stats_.clear();
nodes_.clear();
stats_.resize(n_nodes);
nodes_.resize(n_nodes);
for (int32_t i = 0; i < n_nodes; ++i) {
@@ -692,13 +722,19 @@ void RegTree::LoadModel(Json const& in) {
n = Node{left, right, parent, ind, cond, dft_left};
}
deleted_nodes_.resize(0);
deleted_nodes_.clear();
for (bst_node_t i = 1; i < param.num_nodes; ++i) {
if (nodes_[i].IsDeleted()) {
deleted_nodes_.push_back(i);
}
}
auto& self = *this;
for (auto nid = 1; nid < n_nodes; ++nid) {
auto parent = self[nid].Parent();
CHECK_NE(parent, RegTree::kInvalidNodeId);
self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid);
}
CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param.num_deleted);
}