Model IO in JSON. (#5110)
This commit is contained in:
@@ -8,12 +8,15 @@
|
||||
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/json.h>
|
||||
|
||||
#include <sstream>
|
||||
#include <limits>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
|
||||
#include "param.h"
|
||||
#include "../common/common.h"
|
||||
|
||||
namespace xgboost {
|
||||
// register tree parameter
|
||||
@@ -615,7 +618,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||
return result;
|
||||
}
|
||||
|
||||
void RegTree::LoadModel(dmlc::Stream* fi) {
|
||||
void RegTree::Load(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam));
|
||||
nodes_.resize(param.num_nodes);
|
||||
stats_.resize(param.num_nodes);
|
||||
@@ -633,11 +636,7 @@ void RegTree::LoadModel(dmlc::Stream* fi) {
|
||||
}
|
||||
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
|
||||
}
|
||||
/*!
|
||||
* \brief save model to stream
|
||||
* \param fo output stream
|
||||
*/
|
||||
void RegTree::SaveModel(dmlc::Stream* fo) const {
|
||||
void RegTree::Save(dmlc::Stream* fo) const {
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
|
||||
fo->Write(¶m, sizeof(TreeParam));
|
||||
@@ -646,6 +645,114 @@ void RegTree::SaveModel(dmlc::Stream* fo) const {
|
||||
fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size());
|
||||
}
|
||||
|
||||
void RegTree::LoadModel(Json const& in) {
|
||||
fromJson(in["tree_param"], ¶m);
|
||||
auto n_nodes = param.num_nodes;
|
||||
CHECK_NE(n_nodes, 0);
|
||||
// stats
|
||||
auto const& loss_changes = get<Array const>(in["loss_changes"]);
|
||||
CHECK_EQ(loss_changes.size(), n_nodes);
|
||||
auto const& sum_hessian = get<Array const>(in["sum_hessian"]);
|
||||
CHECK_EQ(sum_hessian.size(), n_nodes);
|
||||
auto const& base_weights = get<Array const>(in["base_weights"]);
|
||||
CHECK_EQ(base_weights.size(), n_nodes);
|
||||
auto const& leaf_child_counts = get<Array const>(in["leaf_child_counts"]);
|
||||
CHECK_EQ(leaf_child_counts.size(), n_nodes);
|
||||
// nodes
|
||||
auto const& lefts = get<Array const>(in["left_children"]);
|
||||
CHECK_EQ(lefts.size(), n_nodes);
|
||||
auto const& rights = get<Array const>(in["right_children"]);
|
||||
CHECK_EQ(rights.size(), n_nodes);
|
||||
auto const& parents = get<Array const>(in["parents"]);
|
||||
CHECK_EQ(parents.size(), n_nodes);
|
||||
auto const& indices = get<Array const>(in["split_indices"]);
|
||||
CHECK_EQ(indices.size(), n_nodes);
|
||||
auto const& conds = get<Array const>(in["split_conditions"]);
|
||||
CHECK_EQ(conds.size(), n_nodes);
|
||||
auto const& default_left = get<Array const>(in["default_left"]);
|
||||
CHECK_EQ(default_left.size(), n_nodes);
|
||||
|
||||
stats_.resize(n_nodes);
|
||||
nodes_.resize(n_nodes);
|
||||
for (int32_t i = 0; i < n_nodes; ++i) {
|
||||
auto& s = stats_[i];
|
||||
s.loss_chg = get<Number const>(loss_changes[i]);
|
||||
s.sum_hess = get<Number const>(sum_hessian[i]);
|
||||
s.base_weight = get<Number const>(base_weights[i]);
|
||||
s.leaf_child_cnt = get<Integer const>(leaf_child_counts[i]);
|
||||
|
||||
auto& n = nodes_[i];
|
||||
auto left = get<Integer const>(lefts[i]);
|
||||
auto right = get<Integer const>(rights[i]);
|
||||
auto parent = get<Integer const>(parents[i]);
|
||||
auto ind = get<Integer const>(indices[i]);
|
||||
auto cond = get<Number const>(conds[i]);
|
||||
auto dft_left = get<Boolean const>(default_left[i]);
|
||||
n = Node(left, right, parent, ind, cond, dft_left);
|
||||
}
|
||||
|
||||
|
||||
deleted_nodes_.resize(0);
|
||||
for (bst_node_t i = 1; i < param.num_nodes; ++i) {
|
||||
if (nodes_[i].IsDeleted()) {
|
||||
deleted_nodes_.push_back(i);
|
||||
}
|
||||
}
|
||||
CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param.num_deleted);
|
||||
}
|
||||
|
||||
void RegTree::SaveModel(Json* p_out) const {
|
||||
auto& out = *p_out;
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
|
||||
out["tree_param"] = toJson(param);
|
||||
CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes));
|
||||
using I = Integer::Int;
|
||||
auto n_nodes = param.num_nodes;
|
||||
|
||||
// stats
|
||||
std::vector<Json> loss_changes(n_nodes);
|
||||
std::vector<Json> sum_hessian(n_nodes);
|
||||
std::vector<Json> base_weights(n_nodes);
|
||||
std::vector<Json> leaf_child_counts(n_nodes);
|
||||
|
||||
// nodes
|
||||
std::vector<Json> lefts(n_nodes);
|
||||
std::vector<Json> rights(n_nodes);
|
||||
std::vector<Json> parents(n_nodes);
|
||||
std::vector<Json> indices(n_nodes);
|
||||
std::vector<Json> conds(n_nodes);
|
||||
std::vector<Json> default_left(n_nodes);
|
||||
|
||||
for (int32_t i = 0; i < n_nodes; ++i) {
|
||||
auto const& s = stats_[i];
|
||||
loss_changes[i] = s.loss_chg;
|
||||
sum_hessian[i] = s.sum_hess;
|
||||
base_weights[i] = s.base_weight;
|
||||
leaf_child_counts[i] = static_cast<I>(s.leaf_child_cnt);
|
||||
|
||||
auto const& n = nodes_[i];
|
||||
lefts[i] = static_cast<I>(n.LeftChild());
|
||||
rights[i] = static_cast<I>(n.RightChild());
|
||||
parents[i] = static_cast<I>(n.Parent());
|
||||
indices[i] = static_cast<I>(n.SplitIndex());
|
||||
conds[i] = n.SplitCond();
|
||||
default_left[i] = n.DefaultLeft();
|
||||
}
|
||||
|
||||
out["loss_changes"] = std::move(loss_changes);
|
||||
out["sum_hessian"] = std::move(sum_hessian);
|
||||
out["base_weights"] = std::move(base_weights);
|
||||
out["leaf_child_counts"] = std::move(leaf_child_counts);
|
||||
|
||||
out["left_children"] = std::move(lefts);
|
||||
out["right_children"] = std::move(rights);
|
||||
out["parents"] = std::move(parents);
|
||||
out["split_indices"] = std::move(indices);
|
||||
out["split_conditions"] = std::move(conds);
|
||||
out["default_left"] = std::move(default_left);
|
||||
}
|
||||
|
||||
void RegTree::FillNodeMeanValues() {
|
||||
size_t num_nodes = this->param.num_nodes;
|
||||
if (this->node_mean_values_.size() == num_nodes) {
|
||||
|
||||
Reference in New Issue
Block a user