Small refinements for JSON model. (#5112)
* Naming consistency. * Remove duplicated test.
This commit is contained in:
parent
208ab3b1ff
commit
ad4a1c732c
@ -40,7 +40,7 @@ void GBTreeModel::Load(dmlc::Stream* fi) {
|
||||
void GBTreeModel::SaveModel(Json* p_out) const {
|
||||
auto& out = *p_out;
|
||||
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
||||
out["model_param"] = toJson(param);
|
||||
out["gbtree_model_param"] = toJson(param);
|
||||
std::vector<Json> trees_json;
|
||||
size_t t = 0;
|
||||
for (auto const& tree : trees) {
|
||||
@ -61,7 +61,7 @@ void GBTreeModel::SaveModel(Json* p_out) const {
|
||||
}
|
||||
|
||||
void GBTreeModel::LoadModel(Json const& in) {
|
||||
fromJson(in["model_param"], ¶m);
|
||||
fromJson(in["gbtree_model_param"], ¶m);
|
||||
|
||||
trees.clear();
|
||||
trees_to_update.clear();
|
||||
|
||||
@ -269,7 +269,7 @@ class LearnerImpl : public Learner {
|
||||
void LoadModel(Json const& in) override {
|
||||
CHECK(IsA<Object>(in));
|
||||
Version::Load(in, false);
|
||||
auto const& learner = get<Object>(in["Learner"]);
|
||||
auto const& learner = get<Object>(in["learner"]);
|
||||
mparam_.FromJson(learner.at("learner_model_param"));
|
||||
|
||||
auto const& objective_fn = learner.at("objective");
|
||||
@ -305,8 +305,8 @@ class LearnerImpl : public Learner {
|
||||
Version::Save(p_out);
|
||||
Json& out { *p_out };
|
||||
|
||||
out["Learner"] = Object();
|
||||
auto& learner = out["Learner"];
|
||||
out["learner"] = Object();
|
||||
auto& learner = out["learner"];
|
||||
|
||||
learner["learner_model_param"] = mparam_.ToJson();
|
||||
learner["gradient_booster"] = Object();
|
||||
|
||||
@ -35,22 +35,14 @@ TEST(GBLinear, Json_IO) {
|
||||
std::string model_str;
|
||||
Json::Dump(model, &model_str);
|
||||
|
||||
model = Json::Load({model_str.c_str(), model_str.size()});
|
||||
model = Json::Load(StringView{model_str.c_str(), model_str.size()});
|
||||
ASSERT_TRUE(IsA<Object>(model));
|
||||
model = model["model"];
|
||||
|
||||
{
|
||||
model = model["model"];
|
||||
auto weights = get<Array>(model["weights"]);
|
||||
ASSERT_EQ(weights.size(), 17);
|
||||
}
|
||||
|
||||
{
|
||||
model = Json::Load({model_str.c_str(), model_str.size()});
|
||||
model = model["model"];
|
||||
auto weights = get<Array>(model["weights"]);
|
||||
ASSERT_EQ(weights.size(), 17); // 16 + 1 (bias)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} // namespace gbm
|
||||
|
||||
@ -64,7 +64,7 @@ TEST(GBTree, ChoosePredictor) {
|
||||
}
|
||||
ASSERT_TRUE(data.HostCanWrite());
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string fname = tempdir.path + "/model_para.bst";
|
||||
const std::string fname = tempdir.path + "/model_param.bst";
|
||||
|
||||
{
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
||||
@ -117,17 +117,15 @@ TEST(GBTree, Json_IO) {
|
||||
CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &gparam) };
|
||||
|
||||
Json model {Object()};
|
||||
model["model"] = Object();
|
||||
auto& j_model = model["model"];
|
||||
|
||||
gbm->SaveModel(&j_model);
|
||||
gbm->SaveModel(&model);
|
||||
|
||||
std::stringstream ss;
|
||||
Json::Dump(model, &ss);
|
||||
std::string model_str;
|
||||
Json::Dump(model, &model_str);
|
||||
|
||||
auto model_str = ss.str();
|
||||
model = Json::Load({model_str.c_str(), model_str.size()});
|
||||
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
|
||||
auto loaded_model = Json::Load(StringView{model_str.c_str(), model_str.size()});
|
||||
ASSERT_EQ(get<String>(loaded_model["name"]), "gbtree");
|
||||
ASSERT_TRUE(IsA<Object>(loaded_model["model"]["gbtree_model_param"]));
|
||||
}
|
||||
|
||||
TEST(Dart, Json_IO) {
|
||||
|
||||
@ -143,7 +143,7 @@ TEST(Learner, Json_ModelIO) {
|
||||
for (int32_t iter = 0; iter < kIters; ++iter) {
|
||||
learner->UpdateOneIter(iter, p_dmat.get());
|
||||
}
|
||||
learner->SetAttr("bset_score", "15.2");
|
||||
learner->SetAttr("best_score", "15.2");
|
||||
|
||||
Json out { Object() };
|
||||
learner->SaveModel(&out);
|
||||
@ -153,8 +153,8 @@ TEST(Learner, Json_ModelIO) {
|
||||
learner->Configure();
|
||||
learner->SaveModel(&new_in);
|
||||
|
||||
ASSERT_TRUE(IsA<Object>(out["Learner"]["attributes"]));
|
||||
ASSERT_EQ(get<Object>(out["Learner"]["attributes"]).size(), 1);
|
||||
ASSERT_TRUE(IsA<Object>(out["learner"]["attributes"]));
|
||||
ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1);
|
||||
ASSERT_EQ(out, new_in);
|
||||
}
|
||||
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "../../../src/data/sparse_page_source.h"
|
||||
#include "../../../src/gbm/gbtree_model.h"
|
||||
#include "../../../src/tree/updater_gpu_hist.cu"
|
||||
#include "../../../src/tree/updater_gpu_common.cuh"
|
||||
#include "../../../src/common/common.h"
|
||||
|
||||
@ -213,12 +213,12 @@ class TestModels(unittest.TestCase):
|
||||
|
||||
with open('./model.json', 'r') as fd:
|
||||
j_model = json.load(fd)
|
||||
assert isinstance(j_model['Learner'], dict)
|
||||
assert isinstance(j_model['learner'], dict)
|
||||
|
||||
bst = xgb.Booster(model_file='./model.json')
|
||||
|
||||
with open('./model.json', 'r') as fd:
|
||||
j_model = json.load(fd)
|
||||
assert isinstance(j_model['Learner'], dict)
|
||||
assert isinstance(j_model['learner'], dict)
|
||||
|
||||
os.remove('model.json')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user