Model IO in JSON. (#5110)

This commit is contained in:
Jiaming Yuan
2019-12-11 11:20:40 +08:00
committed by GitHub
parent c7cc657a4d
commit 208ab3b1ff
25 changed files with 667 additions and 165 deletions

View File

@@ -3,6 +3,7 @@
#include <xgboost/tree_model.h>
#include "../helpers.h"
#include "dmlc/filesystem.h"
#include "xgboost/json_io.h"
namespace xgboost {
// Manually construct tree in binary format
@@ -77,7 +78,7 @@ TEST(Tree, Load) {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(tmp_file.c_str(), "r"));
xgboost::RegTree tree;
tree.LoadModel(fi.get());
tree.Load(fi.get());
EXPECT_EQ(tree.GetDepth(1), 1);
EXPECT_EQ(tree[0].SplitCond(), 0.5f);
EXPECT_EQ(tree[0].SplitIndex(), 5);
@@ -218,4 +219,30 @@ TEST(Tree, DumpDot) {
str = tree.DumpModel(fmap, true, R"(dot:{"graph_attrs": {"bgcolor": "#FFFF00"}})");
ASSERT_NE(str.find(R"(graph [ bgcolor="#FFFF00" ])"), std::string::npos);
}
TEST(Tree, Json_IO) {
RegTree tree;
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
Json j_tree{Object()};
tree.SaveModel(&j_tree);
std::stringstream ss;
Json::Dump(j_tree, &ss);
auto tparam = j_tree["tree_param"];
ASSERT_EQ(get<String>(tparam["num_feature"]), "0");
ASSERT_EQ(get<String>(tparam["num_nodes"]), "3");
ASSERT_EQ(get<String>(tparam["size_leaf_vector"]), "0");
ASSERT_EQ(get<Array const>(j_tree["left_children"]).size(), 3);
ASSERT_EQ(get<Array const>(j_tree["right_children"]).size(), 3);
ASSERT_EQ(get<Array const>(j_tree["parents"]).size(), 3);
ASSERT_EQ(get<Array const>(j_tree["split_indices"]).size(), 3);
ASSERT_EQ(get<Array const>(j_tree["split_conditions"]).size(), 3);
ASSERT_EQ(get<Array const>(j_tree["default_left"]).size(), 3);
RegTree loaded_tree;
loaded_tree.LoadModel(j_tree);
ASSERT_EQ(loaded_tree.param.num_nodes, 3);
}
} // namespace xgboost