Model IO in JSON. (#5110)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include <xgboost/tree_model.h>
|
||||
#include "../helpers.h"
|
||||
#include "dmlc/filesystem.h"
|
||||
#include "xgboost/json_io.h"
|
||||
|
||||
namespace xgboost {
|
||||
// Manually construct tree in binary format
|
||||
@@ -77,7 +78,7 @@ TEST(Tree, Load) {
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(tmp_file.c_str(), "r"));
|
||||
|
||||
xgboost::RegTree tree;
|
||||
tree.LoadModel(fi.get());
|
||||
tree.Load(fi.get());
|
||||
EXPECT_EQ(tree.GetDepth(1), 1);
|
||||
EXPECT_EQ(tree[0].SplitCond(), 0.5f);
|
||||
EXPECT_EQ(tree[0].SplitIndex(), 5);
|
||||
@@ -218,4 +219,30 @@ TEST(Tree, DumpDot) {
|
||||
str = tree.DumpModel(fmap, true, R"(dot:{"graph_attrs": {"bgcolor": "#FFFF00"}})");
|
||||
ASSERT_NE(str.find(R"(graph [ bgcolor="#FFFF00" ])"), std::string::npos);
|
||||
}
|
||||
|
||||
TEST(Tree, Json_IO) {
|
||||
RegTree tree;
|
||||
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
Json j_tree{Object()};
|
||||
tree.SaveModel(&j_tree);
|
||||
std::stringstream ss;
|
||||
Json::Dump(j_tree, &ss);
|
||||
|
||||
auto tparam = j_tree["tree_param"];
|
||||
ASSERT_EQ(get<String>(tparam["num_feature"]), "0");
|
||||
ASSERT_EQ(get<String>(tparam["num_nodes"]), "3");
|
||||
ASSERT_EQ(get<String>(tparam["size_leaf_vector"]), "0");
|
||||
|
||||
ASSERT_EQ(get<Array const>(j_tree["left_children"]).size(), 3);
|
||||
ASSERT_EQ(get<Array const>(j_tree["right_children"]).size(), 3);
|
||||
ASSERT_EQ(get<Array const>(j_tree["parents"]).size(), 3);
|
||||
ASSERT_EQ(get<Array const>(j_tree["split_indices"]).size(), 3);
|
||||
ASSERT_EQ(get<Array const>(j_tree["split_conditions"]).size(), 3);
|
||||
ASSERT_EQ(get<Array const>(j_tree["default_left"]).size(), 3);
|
||||
|
||||
RegTree loaded_tree;
|
||||
loaded_tree.LoadModel(j_tree);
|
||||
ASSERT_EQ(loaded_tree.param.num_nodes, 3);
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user