Check whether current updater can modify a tree. (#5406)
* Check whether current updater can modify a tree. * Fix tree model JSON IO for pruned trees.
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include <limits>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <stack>
|
||||
|
||||
#include "param.h"
|
||||
#include "../common/common.h"
|
||||
@@ -618,6 +619,32 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||
return result;
|
||||
}
|
||||
|
||||
bool RegTree::Equal(const RegTree& b) const {
|
||||
if (NumExtraNodes() != b.NumExtraNodes()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::stack<bst_node_t> nodes;
|
||||
nodes.push(0);
|
||||
auto& self = *this;
|
||||
while (!nodes.empty()) {
|
||||
auto nid = nodes.top();
|
||||
nodes.pop();
|
||||
if (!(self.nodes_.at(nid) == b.nodes_.at(nid))) {
|
||||
return false;
|
||||
}
|
||||
auto left = self[nid].LeftChild();
|
||||
auto right = self[nid].RightChild();
|
||||
if (left != RegTree::kInvalidNodeId) {
|
||||
nodes.push(left);
|
||||
}
|
||||
if (right != RegTree::kInvalidNodeId) {
|
||||
nodes.push(right);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void RegTree::Load(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam));
|
||||
nodes_.resize(param.num_nodes);
|
||||
@@ -673,6 +700,9 @@ void RegTree::LoadModel(Json const& in) {
|
||||
auto const& default_left = get<Array const>(in["default_left"]);
|
||||
CHECK_EQ(default_left.size(), n_nodes);
|
||||
|
||||
stats_.clear();
|
||||
nodes_.clear();
|
||||
|
||||
stats_.resize(n_nodes);
|
||||
nodes_.resize(n_nodes);
|
||||
for (int32_t i = 0; i < n_nodes; ++i) {
|
||||
@@ -692,13 +722,19 @@ void RegTree::LoadModel(Json const& in) {
|
||||
n = Node{left, right, parent, ind, cond, dft_left};
|
||||
}
|
||||
|
||||
|
||||
deleted_nodes_.resize(0);
|
||||
deleted_nodes_.clear();
|
||||
for (bst_node_t i = 1; i < param.num_nodes; ++i) {
|
||||
if (nodes_[i].IsDeleted()) {
|
||||
deleted_nodes_.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
auto& self = *this;
|
||||
for (auto nid = 1; nid < n_nodes; ++nid) {
|
||||
auto parent = self[nid].Parent();
|
||||
CHECK_NE(parent, RegTree::kInvalidNodeId);
|
||||
self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid);
|
||||
}
|
||||
CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param.num_deleted);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user