Small refinements for JSON model. (#5112)
* Naming consistency. * Remove duplicated test.
This commit is contained in:
parent
208ab3b1ff
commit
ad4a1c732c
@ -40,7 +40,7 @@ void GBTreeModel::Load(dmlc::Stream* fi) {
|
|||||||
void GBTreeModel::SaveModel(Json* p_out) const {
|
void GBTreeModel::SaveModel(Json* p_out) const {
|
||||||
auto& out = *p_out;
|
auto& out = *p_out;
|
||||||
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
||||||
out["model_param"] = toJson(param);
|
out["gbtree_model_param"] = toJson(param);
|
||||||
std::vector<Json> trees_json;
|
std::vector<Json> trees_json;
|
||||||
size_t t = 0;
|
size_t t = 0;
|
||||||
for (auto const& tree : trees) {
|
for (auto const& tree : trees) {
|
||||||
@ -61,7 +61,7 @@ void GBTreeModel::SaveModel(Json* p_out) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GBTreeModel::LoadModel(Json const& in) {
|
void GBTreeModel::LoadModel(Json const& in) {
|
||||||
fromJson(in["model_param"], ¶m);
|
fromJson(in["gbtree_model_param"], ¶m);
|
||||||
|
|
||||||
trees.clear();
|
trees.clear();
|
||||||
trees_to_update.clear();
|
trees_to_update.clear();
|
||||||
|
|||||||
@ -269,7 +269,7 @@ class LearnerImpl : public Learner {
|
|||||||
void LoadModel(Json const& in) override {
|
void LoadModel(Json const& in) override {
|
||||||
CHECK(IsA<Object>(in));
|
CHECK(IsA<Object>(in));
|
||||||
Version::Load(in, false);
|
Version::Load(in, false);
|
||||||
auto const& learner = get<Object>(in["Learner"]);
|
auto const& learner = get<Object>(in["learner"]);
|
||||||
mparam_.FromJson(learner.at("learner_model_param"));
|
mparam_.FromJson(learner.at("learner_model_param"));
|
||||||
|
|
||||||
auto const& objective_fn = learner.at("objective");
|
auto const& objective_fn = learner.at("objective");
|
||||||
@ -305,8 +305,8 @@ class LearnerImpl : public Learner {
|
|||||||
Version::Save(p_out);
|
Version::Save(p_out);
|
||||||
Json& out { *p_out };
|
Json& out { *p_out };
|
||||||
|
|
||||||
out["Learner"] = Object();
|
out["learner"] = Object();
|
||||||
auto& learner = out["Learner"];
|
auto& learner = out["learner"];
|
||||||
|
|
||||||
learner["learner_model_param"] = mparam_.ToJson();
|
learner["learner_model_param"] = mparam_.ToJson();
|
||||||
learner["gradient_booster"] = Object();
|
learner["gradient_booster"] = Object();
|
||||||
|
|||||||
@ -35,22 +35,14 @@ TEST(GBLinear, Json_IO) {
|
|||||||
std::string model_str;
|
std::string model_str;
|
||||||
Json::Dump(model, &model_str);
|
Json::Dump(model, &model_str);
|
||||||
|
|
||||||
model = Json::Load({model_str.c_str(), model_str.size()});
|
model = Json::Load(StringView{model_str.c_str(), model_str.size()});
|
||||||
ASSERT_TRUE(IsA<Object>(model));
|
ASSERT_TRUE(IsA<Object>(model));
|
||||||
model = model["model"];
|
|
||||||
|
|
||||||
{
|
{
|
||||||
|
model = model["model"];
|
||||||
auto weights = get<Array>(model["weights"]);
|
auto weights = get<Array>(model["weights"]);
|
||||||
ASSERT_EQ(weights.size(), 17);
|
ASSERT_EQ(weights.size(), 17);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
|
||||||
model = Json::Load({model_str.c_str(), model_str.size()});
|
|
||||||
model = model["model"];
|
|
||||||
auto weights = get<Array>(model["weights"]);
|
|
||||||
ASSERT_EQ(weights.size(), 17); // 16 + 1 (bias)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gbm
|
} // namespace gbm
|
||||||
|
|||||||
@ -64,7 +64,7 @@ TEST(GBTree, ChoosePredictor) {
|
|||||||
}
|
}
|
||||||
ASSERT_TRUE(data.HostCanWrite());
|
ASSERT_TRUE(data.HostCanWrite());
|
||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tempdir;
|
||||||
const std::string fname = tempdir.path + "/model_para.bst";
|
const std::string fname = tempdir.path + "/model_param.bst";
|
||||||
|
|
||||||
{
|
{
|
||||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
||||||
@ -117,17 +117,15 @@ TEST(GBTree, Json_IO) {
|
|||||||
CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &gparam) };
|
CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &gparam) };
|
||||||
|
|
||||||
Json model {Object()};
|
Json model {Object()};
|
||||||
model["model"] = Object();
|
|
||||||
auto& j_model = model["model"];
|
|
||||||
|
|
||||||
gbm->SaveModel(&j_model);
|
gbm->SaveModel(&model);
|
||||||
|
|
||||||
std::stringstream ss;
|
std::string model_str;
|
||||||
Json::Dump(model, &ss);
|
Json::Dump(model, &model_str);
|
||||||
|
|
||||||
auto model_str = ss.str();
|
auto loaded_model = Json::Load(StringView{model_str.c_str(), model_str.size()});
|
||||||
model = Json::Load({model_str.c_str(), model_str.size()});
|
ASSERT_EQ(get<String>(loaded_model["name"]), "gbtree");
|
||||||
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
|
ASSERT_TRUE(IsA<Object>(loaded_model["model"]["gbtree_model_param"]));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Dart, Json_IO) {
|
TEST(Dart, Json_IO) {
|
||||||
|
|||||||
@ -143,7 +143,7 @@ TEST(Learner, Json_ModelIO) {
|
|||||||
for (int32_t iter = 0; iter < kIters; ++iter) {
|
for (int32_t iter = 0; iter < kIters; ++iter) {
|
||||||
learner->UpdateOneIter(iter, p_dmat.get());
|
learner->UpdateOneIter(iter, p_dmat.get());
|
||||||
}
|
}
|
||||||
learner->SetAttr("bset_score", "15.2");
|
learner->SetAttr("best_score", "15.2");
|
||||||
|
|
||||||
Json out { Object() };
|
Json out { Object() };
|
||||||
learner->SaveModel(&out);
|
learner->SaveModel(&out);
|
||||||
@ -153,8 +153,8 @@ TEST(Learner, Json_ModelIO) {
|
|||||||
learner->Configure();
|
learner->Configure();
|
||||||
learner->SaveModel(&new_in);
|
learner->SaveModel(&new_in);
|
||||||
|
|
||||||
ASSERT_TRUE(IsA<Object>(out["Learner"]["attributes"]));
|
ASSERT_TRUE(IsA<Object>(out["learner"]["attributes"]));
|
||||||
ASSERT_EQ(get<Object>(out["Learner"]["attributes"]).size(), 1);
|
ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1);
|
||||||
ASSERT_EQ(out, new_in);
|
ASSERT_EQ(out, new_in);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -13,7 +13,6 @@
|
|||||||
|
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "../../../src/data/sparse_page_source.h"
|
#include "../../../src/data/sparse_page_source.h"
|
||||||
#include "../../../src/gbm/gbtree_model.h"
|
|
||||||
#include "../../../src/tree/updater_gpu_hist.cu"
|
#include "../../../src/tree/updater_gpu_hist.cu"
|
||||||
#include "../../../src/tree/updater_gpu_common.cuh"
|
#include "../../../src/tree/updater_gpu_common.cuh"
|
||||||
#include "../../../src/common/common.h"
|
#include "../../../src/common/common.h"
|
||||||
|
|||||||
@ -213,12 +213,12 @@ class TestModels(unittest.TestCase):
|
|||||||
|
|
||||||
with open('./model.json', 'r') as fd:
|
with open('./model.json', 'r') as fd:
|
||||||
j_model = json.load(fd)
|
j_model = json.load(fd)
|
||||||
assert isinstance(j_model['Learner'], dict)
|
assert isinstance(j_model['learner'], dict)
|
||||||
|
|
||||||
bst = xgb.Booster(model_file='./model.json')
|
bst = xgb.Booster(model_file='./model.json')
|
||||||
|
|
||||||
with open('./model.json', 'r') as fd:
|
with open('./model.json', 'r') as fd:
|
||||||
j_model = json.load(fd)
|
j_model = json.load(fd)
|
||||||
assert isinstance(j_model['Learner'], dict)
|
assert isinstance(j_model['learner'], dict)
|
||||||
|
|
||||||
os.remove('model.json')
|
os.remove('model.json')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user