Implement tree model dump with a code generator. (#4602)

* Implement tree model dump with a code generator.

* Split up generators.
* Implement graphviz generator.
* Use pattern matching.

* [Breaking] Return a Source in `to_graphviz` instead of a Digraph in the Python package.


Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2019-06-26 15:20:44 +08:00
committed by GitHub
parent fe2de6f415
commit 8bdf15120a
11 changed files with 802 additions and 264 deletions

View File

@@ -101,4 +101,121 @@ TEST(Tree, AllocateNode) {
ASSERT_TRUE(nodes.at(1).IsLeaf());
ASSERT_TRUE(nodes.at(2).IsLeaf());
}
/*!
 * \brief Build the small fixture tree shared by the dump tests.
 *
 * Node 0 is expanded first, then each of its two children is expanded in
 * turn (left child before right child).  The split feature index and split
 * value of each expanded node are 0/0.0, 1/1.0 and 2/2.0 respectively.  The
 * five trailing numeric arguments of every ExpandNode call are all zero;
 * their meaning is defined by RegTree::ExpandNode.
 *
 * \return The constructed tree.
 */
RegTree ConstructTree() {
  RegTree tree;

  // Root node: feature 0, threshold 0.0, missing values default to the left.
  tree.ExpandNode(/*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
                  /*default_left=*/true,
                  0.0f, 0.0f, 0.0f, 0.0f, 0.0f);

  // Read both child ids of the root before expanding either of them.
  auto const root_left = tree[0].LeftChild();
  auto const root_right = tree[0].RightChild();

  // Left child of the root: feature 1, threshold 1.0, default_left is false.
  tree.ExpandNode(/*nid=*/root_left, /*split_index=*/1, /*split_value=*/1.0f,
                  /*default_left=*/false,
                  0.0f, 0.0f, 0.0f, 0.0f, 0.0f);

  // Right child of the root: feature 2, threshold 2.0, default_left is false.
  tree.ExpandNode(/*nid=*/root_right, /*split_index=*/2, /*split_value=*/2.0f,
                  /*default_left=*/false,
                  0.0f, 0.0f, 0.0f, 0.0f, 0.0f);

  return tree;
}
// Checks the "json" dump of the fixture tree: the number of leaf and
// split_condition entries, that names from the feature map appear as split
// names, and that "cover" is absent when the second DumpModel argument is
// false.
TEST(Tree, DumpJson) {
  // Counts (possibly overlapping) occurrences of `token` in `text`.  The scan
  // starts at offset 0 so that a match at the very beginning is not skipped.
  auto count_occurrences = [](std::string const& text,
                              std::string const& token) -> size_t {
    size_t count = 0;
    for (size_t pos = text.find(token); pos != std::string::npos;
         pos = text.find(token, pos + 1)) {
      ++count;
    }
    return count;
  };

  auto tree = ConstructTree();
  FeatureMap fmap;

  // Dump with an empty feature map first.
  auto str = tree.DumpModel(fmap, true, "json");
  // Unsigned literals keep the comparison with size_t free of sign-compare
  // warnings.
  ASSERT_EQ(count_occurrences(str, "leaf"), 4u);
  ASSERT_EQ(count_occurrences(str, "split_condition"), 3u);

  // With named features, the names must be used for the "split" fields.
  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");
  str = tree.DumpModel(fmap, true, "json");
  ASSERT_NE(str.find(R"("split": "feat_0")"), std::string::npos);
  ASSERT_NE(str.find(R"("split": "feat_1")"), std::string::npos);
  ASSERT_NE(str.find(R"("split": "feat_2")"), std::string::npos);

  // Passing false as the second argument must leave "cover" out of the dump.
  str = tree.DumpModel(fmap, false, "json");
  ASSERT_EQ(str.find("cover"), std::string::npos);
}
// Checks the "text" dump of the fixture tree: the number of leaf and gain
// entries, the rendering of split conditions with default and with named
// features, and that "cover" is absent when the second DumpModel argument is
// false.
TEST(Tree, DumpText) {
  // Counts (possibly overlapping) occurrences of `token` in `text`.  The scan
  // starts at offset 0 so that a match at the very beginning is not skipped.
  auto count_occurrences = [](std::string const& text,
                              std::string const& token) -> size_t {
    size_t count = 0;
    for (size_t pos = text.find(token); pos != std::string::npos;
         pos = text.find(token, pos + 1)) {
      ++count;
    }
    return count;
  };

  auto tree = ConstructTree();
  FeatureMap fmap;

  // Dump with an empty feature map first.
  auto str = tree.DumpModel(fmap, true, "text");
  // Unsigned literals keep the comparison with size_t free of sign-compare
  // warnings.
  ASSERT_EQ(count_occurrences(str, "leaf"), 4u);
  // One "gain" entry is expected per split node.
  ASSERT_EQ(count_occurrences(str, "gain"), 3u);

  // Without a feature map, features are rendered as f<index>.
  ASSERT_NE(str.find("[f0<0]"), std::string::npos);
  ASSERT_NE(str.find("[f1<1]"), std::string::npos);
  ASSERT_NE(str.find("[f2<2]"), std::string::npos);

  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");
  str = tree.DumpModel(fmap, true, "text");
  // The feature registered with type "i" is expected without a threshold;
  // the other two keep the "<value" part.
  ASSERT_NE(str.find("[feat_0]"), std::string::npos);
  ASSERT_NE(str.find("[feat_1<1]"), std::string::npos);
  ASSERT_NE(str.find("[feat_2<2]"), std::string::npos);

  // Passing false as the second argument must leave "cover" out of the dump.
  str = tree.DumpModel(fmap, false, "text");
  ASSERT_EQ(str.find("cover"), std::string::npos);
}
// Checks the "dot" dump of the fixture tree: the number of leaf entries and
// "->" edges, the rendering of named features, and that graph attributes
// supplied after the "dot:" prefix show up in the output.
TEST(Tree, DumpDot) {
  // Counts (possibly overlapping) occurrences of `token` in `text`.  The scan
  // starts at offset 0 so that a match at the very beginning is not skipped.
  auto count_occurrences = [](std::string const& text,
                              std::string const& token) -> size_t {
    size_t count = 0;
    for (size_t pos = text.find(token); pos != std::string::npos;
         pos = text.find(token, pos + 1)) {
      ++count;
    }
    return count;
  };

  auto tree = ConstructTree();
  FeatureMap fmap;

  // Dump with an empty feature map first.
  auto str = tree.DumpModel(fmap, true, "dot");
  // Unsigned literals keep the comparison with size_t free of sign-compare
  // warnings.
  ASSERT_EQ(count_occurrences(str, "leaf"), 4u);
  ASSERT_EQ(count_occurrences(str, "->"), 6u);

  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");
  str = tree.DumpModel(fmap, true, "dot");
  // The feature registered with type "i" is expected as a bare quoted name;
  // the other two keep the "<value" part.
  ASSERT_NE(str.find(R"("feat_0")"), std::string::npos);
  ASSERT_NE(str.find(R"(feat_1<1)"), std::string::npos);
  ASSERT_NE(str.find(R"(feat_2<2)"), std::string::npos);

  // Graph attributes are passed as JSON after the "dot:" format prefix.
  str = tree.DumpModel(fmap, true, R"(dot:{"graph_attrs": {"bgcolor": "#FFFF00"}})");
  ASSERT_NE(str.find(R"(graph [ bgcolor="#FFFF00" ])"), std::string::npos);
}
} // namespace xgboost