From 9c6e791e64afc8b10365ecc1ec2914dbe3b6554c Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 5 Aug 2020 16:44:52 +0800 Subject: [PATCH] Enforce tree order in JSON. (#5974) * Make JSON model IO more future proof by using tree id in model loading. --- src/gbm/gbtree_model.cc | 24 +++++++++++++----------- tests/cpp/test_learner.cc | 11 ++++++++++- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index a53346797..8ebd8284c 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -1,6 +1,8 @@ /*! - * Copyright 2019 by Contributors + * Copyright 2019-2020 by Contributors */ +#include + #include "xgboost/json.h" #include "xgboost/logging.h" #include "gbtree_model.h" @@ -41,15 +43,14 @@ void GBTreeModel::SaveModel(Json* p_out) const { auto& out = *p_out; CHECK_EQ(param.num_trees, static_cast(trees.size())); out["gbtree_model_param"] = ToJson(param); - std::vector trees_json; - size_t t = 0; - for (auto const& tree : trees) { + std::vector trees_json(trees.size()); + + for (size_t t = 0; t < trees.size(); ++t) { + auto const& tree = trees[t]; Json tree_json{Object()}; tree->SaveModel(&tree_json); - // The field is not used in XGBoost, but might be useful for external project. - tree_json["id"] = Integer(t); - trees_json.emplace_back(tree_json); - t++; + tree_json["id"] = Integer(static_cast(t)); + trees_json[t] = std::move(tree_json); } std::vector tree_info_json(tree_info.size()); @@ -70,9 +71,10 @@ void GBTreeModel::LoadModel(Json const& in) { auto const& trees_json = get(in["trees"]); trees.resize(trees_json.size()); - for (size_t t = 0; t < trees.size(); ++t) { - trees[t].reset( new RegTree() ); - trees[t]->LoadModel(trees_json[t]); + for (size_t t = 0; t < trees_json.size(); ++t) { // NOLINT + auto tree_id = get(trees_json[t]["id"]); + trees.at(tree_id).reset(new RegTree()); + trees.at(tree_id)->LoadModel(trees_json[t]); } tree_info.resize(param.num_trees); diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 7d473f00c..56e4a95ec 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -148,7 +148,16 @@ TEST(Learner, JsonModelIO) { Json out { Object() }; learner->SaveModel(&out); - learner->LoadModel(out); + dmlc::TemporaryDirectory tmpdir; + + std::ofstream fout (tmpdir.path + "/model.json"); + fout << out; + fout.close(); + + auto loaded_str = common::LoadSequentialFile(tmpdir.path + "/model.json"); + Json loaded = Json::Load(StringView{loaded_str.c_str(), loaded_str.size()}); + + learner->LoadModel(loaded); learner->Configure(); Json new_in { Object() };