JSON configuration IO. (#5111)
* Add saving/loading JSON configuration. * Implement Python pickle interface with new IO routines. * Basic tests for training continuation.
This commit is contained in:
@@ -117,15 +117,28 @@ TEST(GBTree, Json_IO) {
|
||||
CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &gparam) };
|
||||
|
||||
Json model {Object()};
|
||||
model["model"] = Object();
|
||||
auto& j_model = model["model"];
|
||||
|
||||
gbm->SaveModel(&model);
|
||||
model["config"] = Object();
|
||||
auto& j_param = model["config"];
|
||||
|
||||
gbm->SaveModel(&j_model);
|
||||
gbm->SaveConfig(&j_param);
|
||||
|
||||
std::string model_str;
|
||||
Json::Dump(model, &model_str);
|
||||
|
||||
auto loaded_model = Json::Load(StringView{model_str.c_str(), model_str.size()});
|
||||
ASSERT_EQ(get<String>(loaded_model["name"]), "gbtree");
|
||||
ASSERT_TRUE(IsA<Object>(loaded_model["model"]["gbtree_model_param"]));
|
||||
model = Json::Load({model_str.c_str(), model_str.size()});
|
||||
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
|
||||
|
||||
auto const& gbtree_model = model["model"]["model"];
|
||||
ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1);
|
||||
ASSERT_EQ(get<Integer>(get<Object>(get<Array>(gbtree_model["trees"]).front()).at("id")), 0);
|
||||
ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1);
|
||||
|
||||
auto j_train_param = model["config"]["gbtree_train_param"];
|
||||
ASSERT_EQ(get<String>(j_train_param["num_parallel_tree"]), "1");
|
||||
}
|
||||
|
||||
TEST(Dart, Json_IO) {
|
||||
@@ -145,20 +158,21 @@ TEST(Dart, Json_IO) {
|
||||
Json model {Object()};
|
||||
model["model"] = Object();
|
||||
auto& j_model = model["model"];
|
||||
model["parameters"] = Object();
|
||||
model["config"] = Object();
|
||||
|
||||
auto& j_param = model["config"];
|
||||
|
||||
gbm->SaveModel(&j_model);
|
||||
gbm->SaveConfig(&j_param);
|
||||
|
||||
std::string model_str;
|
||||
Json::Dump(model, &model_str);
|
||||
|
||||
model = Json::Load({model_str.c_str(), model_str.size()});
|
||||
|
||||
{
|
||||
auto const& gbtree = model["model"]["gbtree"];
|
||||
ASSERT_TRUE(IsA<Object>(gbtree));
|
||||
ASSERT_EQ(get<String>(model["model"]["name"]), "dart");
|
||||
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
|
||||
}
|
||||
ASSERT_EQ(get<String>(model["model"]["name"]), "dart") << model;
|
||||
ASSERT_EQ(get<String>(model["config"]["name"]), "dart");
|
||||
ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"]));
|
||||
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user