Small refinements for JSON model. (#5112)

* Naming consistency.

* Remove duplicated test.
This commit is contained in:
Jiaming Yuan 2019-12-11 19:49:01 +08:00 committed by GitHub
parent 208ab3b1ff
commit ad4a1c732c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 19 additions and 30 deletions

View File

@ -40,7 +40,7 @@ void GBTreeModel::Load(dmlc::Stream* fi) {
void GBTreeModel::SaveModel(Json* p_out) const { void GBTreeModel::SaveModel(Json* p_out) const {
auto& out = *p_out; auto& out = *p_out;
CHECK_EQ(param.num_trees, static_cast<int>(trees.size())); CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
out["model_param"] = toJson(param); out["gbtree_model_param"] = toJson(param);
std::vector<Json> trees_json; std::vector<Json> trees_json;
size_t t = 0; size_t t = 0;
for (auto const& tree : trees) { for (auto const& tree : trees) {
@ -61,7 +61,7 @@ void GBTreeModel::SaveModel(Json* p_out) const {
} }
void GBTreeModel::LoadModel(Json const& in) { void GBTreeModel::LoadModel(Json const& in) {
fromJson(in["model_param"], &param); fromJson(in["gbtree_model_param"], &param);
trees.clear(); trees.clear();
trees_to_update.clear(); trees_to_update.clear();

View File

@ -269,7 +269,7 @@ class LearnerImpl : public Learner {
void LoadModel(Json const& in) override { void LoadModel(Json const& in) override {
CHECK(IsA<Object>(in)); CHECK(IsA<Object>(in));
Version::Load(in, false); Version::Load(in, false);
auto const& learner = get<Object>(in["Learner"]); auto const& learner = get<Object>(in["learner"]);
mparam_.FromJson(learner.at("learner_model_param")); mparam_.FromJson(learner.at("learner_model_param"));
auto const& objective_fn = learner.at("objective"); auto const& objective_fn = learner.at("objective");
@ -305,8 +305,8 @@ class LearnerImpl : public Learner {
Version::Save(p_out); Version::Save(p_out);
Json& out { *p_out }; Json& out { *p_out };
out["Learner"] = Object(); out["learner"] = Object();
auto& learner = out["Learner"]; auto& learner = out["learner"];
learner["learner_model_param"] = mparam_.ToJson(); learner["learner_model_param"] = mparam_.ToJson();
learner["gradient_booster"] = Object(); learner["gradient_booster"] = Object();

View File

@ -35,22 +35,14 @@ TEST(GBLinear, Json_IO) {
std::string model_str; std::string model_str;
Json::Dump(model, &model_str); Json::Dump(model, &model_str);
model = Json::Load({model_str.c_str(), model_str.size()}); model = Json::Load(StringView{model_str.c_str(), model_str.size()});
ASSERT_TRUE(IsA<Object>(model)); ASSERT_TRUE(IsA<Object>(model));
model = model["model"];
{ {
model = model["model"];
auto weights = get<Array>(model["weights"]); auto weights = get<Array>(model["weights"]);
ASSERT_EQ(weights.size(), 17); ASSERT_EQ(weights.size(), 17);
} }
{
model = Json::Load({model_str.c_str(), model_str.size()});
model = model["model"];
auto weights = get<Array>(model["weights"]);
ASSERT_EQ(weights.size(), 17); // 16 + 1 (bias)
}
} }
} // namespace gbm } // namespace gbm

View File

@ -64,7 +64,7 @@ TEST(GBTree, ChoosePredictor) {
} }
ASSERT_TRUE(data.HostCanWrite()); ASSERT_TRUE(data.HostCanWrite());
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string fname = tempdir.path + "/model_para.bst"; const std::string fname = tempdir.path + "/model_param.bst";
{ {
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w")); std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
@ -117,17 +117,15 @@ TEST(GBTree, Json_IO) {
CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &gparam) }; CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &gparam) };
Json model {Object()}; Json model {Object()};
model["model"] = Object();
auto& j_model = model["model"];
gbm->SaveModel(&j_model); gbm->SaveModel(&model);
std::stringstream ss; std::string model_str;
Json::Dump(model, &ss); Json::Dump(model, &model_str);
auto model_str = ss.str(); auto loaded_model = Json::Load(StringView{model_str.c_str(), model_str.size()});
model = Json::Load({model_str.c_str(), model_str.size()}); ASSERT_EQ(get<String>(loaded_model["name"]), "gbtree");
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree"); ASSERT_TRUE(IsA<Object>(loaded_model["model"]["gbtree_model_param"]));
} }
TEST(Dart, Json_IO) { TEST(Dart, Json_IO) {

View File

@ -143,7 +143,7 @@ TEST(Learner, Json_ModelIO) {
for (int32_t iter = 0; iter < kIters; ++iter) { for (int32_t iter = 0; iter < kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get()); learner->UpdateOneIter(iter, p_dmat.get());
} }
learner->SetAttr("bset_score", "15.2"); learner->SetAttr("best_score", "15.2");
Json out { Object() }; Json out { Object() };
learner->SaveModel(&out); learner->SaveModel(&out);
@ -153,8 +153,8 @@ TEST(Learner, Json_ModelIO) {
learner->Configure(); learner->Configure();
learner->SaveModel(&new_in); learner->SaveModel(&new_in);
ASSERT_TRUE(IsA<Object>(out["Learner"]["attributes"])); ASSERT_TRUE(IsA<Object>(out["learner"]["attributes"]));
ASSERT_EQ(get<Object>(out["Learner"]["attributes"]).size(), 1); ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1);
ASSERT_EQ(out, new_in); ASSERT_EQ(out, new_in);
} }

View File

@ -13,7 +13,6 @@
#include "xgboost/json.h" #include "xgboost/json.h"
#include "../../../src/data/sparse_page_source.h" #include "../../../src/data/sparse_page_source.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../../../src/tree/updater_gpu_hist.cu" #include "../../../src/tree/updater_gpu_hist.cu"
#include "../../../src/tree/updater_gpu_common.cuh" #include "../../../src/tree/updater_gpu_common.cuh"
#include "../../../src/common/common.h" #include "../../../src/common/common.h"

View File

@ -213,12 +213,12 @@ class TestModels(unittest.TestCase):
with open('./model.json', 'r') as fd: with open('./model.json', 'r') as fd:
j_model = json.load(fd) j_model = json.load(fd)
assert isinstance(j_model['Learner'], dict) assert isinstance(j_model['learner'], dict)
bst = xgb.Booster(model_file='./model.json') bst = xgb.Booster(model_file='./model.json')
with open('./model.json', 'r') as fd: with open('./model.json', 'r') as fd:
j_model = json.load(fd) j_model = json.load(fd)
assert isinstance(j_model['Learner'], dict) assert isinstance(j_model['learner'], dict)
os.remove('model.json') os.remove('model.json')