Small refinements for JSON model. (#5112)

* Naming consistency.

* Remove duplicated test.
This commit is contained in:
Jiaming Yuan 2019-12-11 19:49:01 +08:00 committed by GitHub
parent 208ab3b1ff
commit ad4a1c732c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 19 additions and 30 deletions

View File

@ -40,7 +40,7 @@ void GBTreeModel::Load(dmlc::Stream* fi) {
void GBTreeModel::SaveModel(Json* p_out) const {
auto& out = *p_out;
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
out["model_param"] = toJson(param);
out["gbtree_model_param"] = toJson(param);
std::vector<Json> trees_json;
size_t t = 0;
for (auto const& tree : trees) {
@ -61,7 +61,7 @@ void GBTreeModel::SaveModel(Json* p_out) const {
}
void GBTreeModel::LoadModel(Json const& in) {
fromJson(in["model_param"], &param);
fromJson(in["gbtree_model_param"], &param);
trees.clear();
trees_to_update.clear();

View File

@ -269,7 +269,7 @@ class LearnerImpl : public Learner {
void LoadModel(Json const& in) override {
CHECK(IsA<Object>(in));
Version::Load(in, false);
auto const& learner = get<Object>(in["Learner"]);
auto const& learner = get<Object>(in["learner"]);
mparam_.FromJson(learner.at("learner_model_param"));
auto const& objective_fn = learner.at("objective");
@ -305,8 +305,8 @@ class LearnerImpl : public Learner {
Version::Save(p_out);
Json& out { *p_out };
out["Learner"] = Object();
auto& learner = out["Learner"];
out["learner"] = Object();
auto& learner = out["learner"];
learner["learner_model_param"] = mparam_.ToJson();
learner["gradient_booster"] = Object();

View File

@ -35,22 +35,14 @@ TEST(GBLinear, Json_IO) {
std::string model_str;
Json::Dump(model, &model_str);
model = Json::Load({model_str.c_str(), model_str.size()});
model = Json::Load(StringView{model_str.c_str(), model_str.size()});
ASSERT_TRUE(IsA<Object>(model));
model = model["model"];
{
model = model["model"];
auto weights = get<Array>(model["weights"]);
ASSERT_EQ(weights.size(), 17);
}
{
model = Json::Load({model_str.c_str(), model_str.size()});
model = model["model"];
auto weights = get<Array>(model["weights"]);
ASSERT_EQ(weights.size(), 17); // 16 + 1 (bias)
}
}
} // namespace gbm

View File

@ -64,7 +64,7 @@ TEST(GBTree, ChoosePredictor) {
}
ASSERT_TRUE(data.HostCanWrite());
dmlc::TemporaryDirectory tempdir;
const std::string fname = tempdir.path + "/model_para.bst";
const std::string fname = tempdir.path + "/model_param.bst";
{
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
@ -117,17 +117,15 @@ TEST(GBTree, Json_IO) {
CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &gparam) };
Json model {Object()};
model["model"] = Object();
auto& j_model = model["model"];
gbm->SaveModel(&j_model);
gbm->SaveModel(&model);
std::stringstream ss;
Json::Dump(model, &ss);
std::string model_str;
Json::Dump(model, &model_str);
auto model_str = ss.str();
model = Json::Load({model_str.c_str(), model_str.size()});
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
auto loaded_model = Json::Load(StringView{model_str.c_str(), model_str.size()});
ASSERT_EQ(get<String>(loaded_model["name"]), "gbtree");
ASSERT_TRUE(IsA<Object>(loaded_model["model"]["gbtree_model_param"]));
}
TEST(Dart, Json_IO) {

View File

@ -143,7 +143,7 @@ TEST(Learner, Json_ModelIO) {
for (int32_t iter = 0; iter < kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
}
learner->SetAttr("bset_score", "15.2");
learner->SetAttr("best_score", "15.2");
Json out { Object() };
learner->SaveModel(&out);
@ -153,8 +153,8 @@ TEST(Learner, Json_ModelIO) {
learner->Configure();
learner->SaveModel(&new_in);
ASSERT_TRUE(IsA<Object>(out["Learner"]["attributes"]));
ASSERT_EQ(get<Object>(out["Learner"]["attributes"]).size(), 1);
ASSERT_TRUE(IsA<Object>(out["learner"]["attributes"]));
ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1);
ASSERT_EQ(out, new_in);
}

View File

@ -13,7 +13,6 @@
#include "xgboost/json.h"
#include "../../../src/data/sparse_page_source.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../../../src/tree/updater_gpu_hist.cu"
#include "../../../src/tree/updater_gpu_common.cuh"
#include "../../../src/common/common.h"

View File

@ -213,12 +213,12 @@ class TestModels(unittest.TestCase):
with open('./model.json', 'r') as fd:
j_model = json.load(fd)
assert isinstance(j_model['Learner'], dict)
assert isinstance(j_model['learner'], dict)
bst = xgb.Booster(model_file='./model.json')
with open('./model.json', 'r') as fd:
j_model = json.load(fd)
assert isinstance(j_model['Learner'], dict)
assert isinstance(j_model['learner'], dict)
os.remove('model.json')