Enforce tree order in JSON. (#5974)
* Make JSON model IO more future proof by using tree id in model loading.
This commit is contained in:
parent
dde9c5aaff
commit
9c6e791e64
@ -1,6 +1,8 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2019 by Contributors
|
* Copyright 2019-2020 by Contributors
|
||||||
*/
|
*/
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "gbtree_model.h"
|
#include "gbtree_model.h"
|
||||||
@ -41,15 +43,14 @@ void GBTreeModel::SaveModel(Json* p_out) const {
|
|||||||
auto& out = *p_out;
|
auto& out = *p_out;
|
||||||
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
||||||
out["gbtree_model_param"] = ToJson(param);
|
out["gbtree_model_param"] = ToJson(param);
|
||||||
std::vector<Json> trees_json;
|
std::vector<Json> trees_json(trees.size());
|
||||||
size_t t = 0;
|
|
||||||
for (auto const& tree : trees) {
|
for (size_t t = 0; t < trees.size(); ++t) {
|
||||||
|
auto const& tree = trees[t];
|
||||||
Json tree_json{Object()};
|
Json tree_json{Object()};
|
||||||
tree->SaveModel(&tree_json);
|
tree->SaveModel(&tree_json);
|
||||||
// The field is not used in XGBoost, but might be useful for external project.
|
tree_json["id"] = Integer(static_cast<Integer::Int>(t));
|
||||||
tree_json["id"] = Integer(t);
|
trees_json[t] = std::move(tree_json);
|
||||||
trees_json.emplace_back(tree_json);
|
|
||||||
t++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Json> tree_info_json(tree_info.size());
|
std::vector<Json> tree_info_json(tree_info.size());
|
||||||
@ -70,9 +71,10 @@ void GBTreeModel::LoadModel(Json const& in) {
|
|||||||
auto const& trees_json = get<Array const>(in["trees"]);
|
auto const& trees_json = get<Array const>(in["trees"]);
|
||||||
trees.resize(trees_json.size());
|
trees.resize(trees_json.size());
|
||||||
|
|
||||||
for (size_t t = 0; t < trees.size(); ++t) {
|
for (size_t t = 0; t < trees_json.size(); ++t) { // NOLINT
|
||||||
trees[t].reset( new RegTree() );
|
auto tree_id = get<Integer>(trees_json[t]["id"]);
|
||||||
trees[t]->LoadModel(trees_json[t]);
|
trees.at(tree_id).reset(new RegTree());
|
||||||
|
trees.at(tree_id)->LoadModel(trees_json[t]);
|
||||||
}
|
}
|
||||||
|
|
||||||
tree_info.resize(param.num_trees);
|
tree_info.resize(param.num_trees);
|
||||||
|
|||||||
@ -148,7 +148,16 @@ TEST(Learner, JsonModelIO) {
|
|||||||
Json out { Object() };
|
Json out { Object() };
|
||||||
learner->SaveModel(&out);
|
learner->SaveModel(&out);
|
||||||
|
|
||||||
learner->LoadModel(out);
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
|
|
||||||
|
std::ofstream fout (tmpdir.path + "/model.json");
|
||||||
|
fout << out;
|
||||||
|
fout.close();
|
||||||
|
|
||||||
|
auto loaded_str = common::LoadSequentialFile(tmpdir.path + "/model.json");
|
||||||
|
Json loaded = Json::Load(StringView{loaded_str.c_str(), loaded_str.size()});
|
||||||
|
|
||||||
|
learner->LoadModel(loaded);
|
||||||
learner->Configure();
|
learner->Configure();
|
||||||
|
|
||||||
Json new_in { Object() };
|
Json new_in { Object() };
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user