JSON configuration IO. (#5111)

* Add saving/loading JSON configuration.
* Implement Python pickle interface with new IO routines.
* Basic tests for training continuation.
This commit is contained in:
Jiaming Yuan
2019-12-15 17:31:53 +08:00
committed by GitHub
parent 5aa007d7b2
commit 3136185bc5
24 changed files with 761 additions and 390 deletions

View File

@@ -117,15 +117,28 @@ TEST(GBTree, Json_IO) {
CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &gparam) };
Json model {Object()};
model["model"] = Object();
auto& j_model = model["model"];
gbm->SaveModel(&model);
model["config"] = Object();
auto& j_param = model["config"];
gbm->SaveModel(&j_model);
gbm->SaveConfig(&j_param);
std::string model_str;
Json::Dump(model, &model_str);
auto loaded_model = Json::Load(StringView{model_str.c_str(), model_str.size()});
ASSERT_EQ(get<String>(loaded_model["name"]), "gbtree");
ASSERT_TRUE(IsA<Object>(loaded_model["model"]["gbtree_model_param"]));
model = Json::Load({model_str.c_str(), model_str.size()});
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
auto const& gbtree_model = model["model"]["model"];
ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1);
ASSERT_EQ(get<Integer>(get<Object>(get<Array>(gbtree_model["trees"]).front()).at("id")), 0);
ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1);
auto j_train_param = model["config"]["gbtree_train_param"];
ASSERT_EQ(get<String>(j_train_param["num_parallel_tree"]), "1");
}
TEST(Dart, Json_IO) {
@@ -145,20 +158,21 @@ TEST(Dart, Json_IO) {
Json model {Object()};
model["model"] = Object();
auto& j_model = model["model"];
model["parameters"] = Object();
model["config"] = Object();
auto& j_param = model["config"];
gbm->SaveModel(&j_model);
gbm->SaveConfig(&j_param);
std::string model_str;
Json::Dump(model, &model_str);
model = Json::Load({model_str.c_str(), model_str.size()});
{
auto const& gbtree = model["model"]["gbtree"];
ASSERT_TRUE(IsA<Object>(gbtree));
ASSERT_EQ(get<String>(model["model"]["name"]), "dart");
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
}
ASSERT_EQ(get<String>(model["model"]["name"]), "dart") << model;
ASSERT_EQ(get<String>(model["config"]["name"]), "dart");
ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"]));
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
}
} // namespace xgboost