Small refinements for JSON model. (#5112)

* Naming consistency.

* Remove duplicated test.
This commit is contained in:
Jiaming Yuan
2019-12-11 19:49:01 +08:00
committed by GitHub
parent 208ab3b1ff
commit ad4a1c732c
7 changed files with 19 additions and 30 deletions

View File

@@ -35,22 +35,14 @@ TEST(GBLinear, Json_IO) {
std::string model_str;
Json::Dump(model, &model_str);
model = Json::Load({model_str.c_str(), model_str.size()});
model = Json::Load(StringView{model_str.c_str(), model_str.size()});
ASSERT_TRUE(IsA<Object>(model));
model = model["model"];
{
model = model["model"];
auto weights = get<Array>(model["weights"]);
ASSERT_EQ(weights.size(), 17);
}
{
model = Json::Load({model_str.c_str(), model_str.size()});
model = model["model"];
auto weights = get<Array>(model["weights"]);
ASSERT_EQ(weights.size(), 17); // 16 + 1 (bias)
}
}
} // namespace gbm

View File

@@ -64,7 +64,7 @@ TEST(GBTree, ChoosePredictor) {
}
ASSERT_TRUE(data.HostCanWrite());
dmlc::TemporaryDirectory tempdir;
const std::string fname = tempdir.path + "/model_para.bst";
const std::string fname = tempdir.path + "/model_param.bst";
{
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
@@ -117,17 +117,15 @@ TEST(GBTree, Json_IO) {
CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &gparam) };
Json model {Object()};
model["model"] = Object();
auto& j_model = model["model"];
gbm->SaveModel(&j_model);
gbm->SaveModel(&model);
std::stringstream ss;
Json::Dump(model, &ss);
std::string model_str;
Json::Dump(model, &model_str);
auto model_str = ss.str();
model = Json::Load({model_str.c_str(), model_str.size()});
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
auto loaded_model = Json::Load(StringView{model_str.c_str(), model_str.size()});
ASSERT_EQ(get<String>(loaded_model["name"]), "gbtree");
ASSERT_TRUE(IsA<Object>(loaded_model["model"]["gbtree_model_param"]));
}
TEST(Dart, Json_IO) {

View File

@@ -143,7 +143,7 @@ TEST(Learner, Json_ModelIO) {
for (int32_t iter = 0; iter < kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
}
learner->SetAttr("bset_score", "15.2");
learner->SetAttr("best_score", "15.2");
Json out { Object() };
learner->SaveModel(&out);
@@ -153,8 +153,8 @@ TEST(Learner, Json_ModelIO) {
learner->Configure();
learner->SaveModel(&new_in);
ASSERT_TRUE(IsA<Object>(out["Learner"]["attributes"]));
ASSERT_EQ(get<Object>(out["Learner"]["attributes"]).size(), 1);
ASSERT_TRUE(IsA<Object>(out["learner"]["attributes"]));
ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1);
ASSERT_EQ(out, new_in);
}

View File

@@ -13,7 +13,6 @@
#include "xgboost/json.h"
#include "../../../src/data/sparse_page_source.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../../../src/tree/updater_gpu_hist.cu"
#include "../../../src/tree/updater_gpu_common.cuh"
#include "../../../src/common/common.h"