Implement tree model dump with code generator. (#4602)
* Implement tree model dump with a code generator.
* Split up generators.
* Implement graphviz generator.
* Use pattern matching.
* [Breaking] Return a `Source` from `to_graphviz` instead of a `Digraph` in the Python package.

Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -101,4 +101,121 @@ TEST(Tree, AllocateNode) {
|
||||
ASSERT_TRUE(nodes.at(1).IsLeaf());
|
||||
ASSERT_TRUE(nodes.at(2).IsLeaf());
|
||||
}
|
||||
|
||||
/*!
 * \brief Build the small fixture tree shared by the dump tests.
 *
 * The root (node 0) is expanded on feature 0 with split value 0.0 and
 * default direction left.  Both of its children are then expanded as well:
 * the left child on feature 1 with split value 1.0, the right child on
 * feature 2 with split value 2.0, each with default_left set to false.
 *
 * The five trailing 0.0f arguments passed to every ExpandNode call are not
 * named in this file; NOTE(review): presumably weights/statistics of the new
 * nodes -- confirm against the RegTree::ExpandNode declaration.
 *
 * \return The constructed tree.
 */
RegTree ConstructTree() {
  RegTree tree;

  // Expand the root first so that its two children exist.
  tree.ExpandNode(/*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
                  /*default_left=*/true,
                  0.0f, 0.0f, 0.0f, 0.0f, 0.0f);

  // Capture both child ids before expanding either of them.
  auto const left_id = tree[0].LeftChild();
  auto const right_id = tree[0].RightChild();

  tree.ExpandNode(/*nid=*/left_id, /*split_index=*/1, /*split_value=*/1.0f,
                  /*default_left=*/false,
                  0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
  tree.ExpandNode(/*nid=*/right_id, /*split_index=*/2, /*split_value=*/2.0f,
                  /*default_left=*/false,
                  0.0f, 0.0f, 0.0f, 0.0f, 0.0f);

  return tree;
}
|
||||
|
||||
// Checks the "json" dump of the tree returned by ConstructTree(): the number
// of leaves and split conditions, the use of feature names once a feature map
// is populated, and the absence of "cover" when the second DumpModel argument
// is false.
TEST(Tree, DumpJson) {
  // Counts (possibly overlapping) occurrences of `needle` in `text`.  The scan
  // starts at index 0 so that a match at the very beginning of the dump is
  // counted too.
  auto count_occurrences = [](std::string const& text,
                              std::string const& needle) {
    size_t count = 0;
    for (size_t pos = text.find(needle); pos != std::string::npos;
         pos = text.find(needle, pos + 1)) {
      ++count;
    }
    return count;
  };

  auto tree = ConstructTree();
  FeatureMap fmap;

  // Empty feature map, second argument true.
  auto str = tree.DumpModel(fmap, true, "json");
  // Unsigned literals keep the comparison against size_t sign-consistent.
  ASSERT_EQ(count_occurrences(str, "leaf"), 4u);
  ASSERT_EQ(count_occurrences(str, "split_condition"), 3u);

  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");

  // With a populated feature map every split must be reported by name.
  str = tree.DumpModel(fmap, true, "json");
  ASSERT_NE(str.find(R"("split": "feat_0")"), std::string::npos);
  ASSERT_NE(str.find(R"("split": "feat_1")"), std::string::npos);
  ASSERT_NE(str.find(R"("split": "feat_2")"), std::string::npos);

  // With the second argument false, no "cover" entry may be emitted.
  str = tree.DumpModel(fmap, false, "json");
  ASSERT_EQ(str.find("cover"), std::string::npos);
}
|
||||
|
||||
// Checks the "text" dump of the tree returned by ConstructTree(): the number
// of leaves and "gain" entries, the default "f<index>" split labels, the
// labels produced once a feature map is populated, and the absence of "cover"
// when the second DumpModel argument is false.
TEST(Tree, DumpText) {
  // Counts (possibly overlapping) occurrences of `needle` in `text`.  The scan
  // starts at index 0 so that a match at the very beginning of the dump is
  // counted too.
  auto count_occurrences = [](std::string const& text,
                              std::string const& needle) {
    size_t count = 0;
    for (size_t pos = text.find(needle); pos != std::string::npos;
         pos = text.find(needle, pos + 1)) {
      ++count;
    }
    return count;
  };

  auto tree = ConstructTree();
  FeatureMap fmap;

  // Empty feature map, second argument true.
  auto str = tree.DumpModel(fmap, true, "text");
  // Unsigned literals keep the comparison against size_t sign-consistent.
  ASSERT_EQ(count_occurrences(str, "leaf"), 4u);
  // One "gain" entry is expected per split node.
  ASSERT_EQ(count_occurrences(str, "gain"), 3u);

  // Without feature names the splits are labelled by feature index.
  ASSERT_NE(str.find("[f0<0]"), std::string::npos);
  ASSERT_NE(str.find("[f1<1]"), std::string::npos);
  ASSERT_NE(str.find("[f2<2]"), std::string::npos);

  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");

  // With a populated feature map the names are used; the feature registered
  // with type "i" is expected to be printed without a threshold.
  str = tree.DumpModel(fmap, true, "text");
  ASSERT_NE(str.find("[feat_0]"), std::string::npos);
  ASSERT_NE(str.find("[feat_1<1]"), std::string::npos);
  ASSERT_NE(str.find("[feat_2<2]"), std::string::npos);

  // With the second argument false, no "cover" entry may be emitted.
  str = tree.DumpModel(fmap, false, "text");
  ASSERT_EQ(str.find("cover"), std::string::npos);
}
|
||||
|
||||
// Checks the "dot" (graphviz) dump of the tree returned by ConstructTree():
// the number of leaves and edges, the node labels produced once a feature map
// is populated, and that graph attributes passed after "dot:" as JSON show up
// in the output.
TEST(Tree, DumpDot) {
  // Counts (possibly overlapping) occurrences of `needle` in `text`.  The scan
  // starts at index 0 so that a match at the very beginning of the dump is
  // counted too.
  auto count_occurrences = [](std::string const& text,
                              std::string const& needle) {
    size_t count = 0;
    for (size_t pos = text.find(needle); pos != std::string::npos;
         pos = text.find(needle, pos + 1)) {
      ++count;
    }
    return count;
  };

  auto tree = ConstructTree();
  FeatureMap fmap;

  // Empty feature map, second argument true.
  auto str = tree.DumpModel(fmap, true, "dot");
  // Unsigned literals keep the comparison against size_t sign-consistent.
  ASSERT_EQ(count_occurrences(str, "leaf"), 4u);
  // Three split nodes with two outgoing edges each.
  ASSERT_EQ(count_occurrences(str, "->"), 6u);

  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");

  // With a populated feature map the names are used; the feature registered
  // with type "i" is expected to appear as a bare quoted name.
  str = tree.DumpModel(fmap, true, "dot");
  ASSERT_NE(str.find(R"("feat_0")"), std::string::npos);
  ASSERT_NE(str.find(R"(feat_1<1)"), std::string::npos);
  ASSERT_NE(str.find(R"(feat_2<2)"), std::string::npos);

  // Graph attributes supplied in the format string must be emitted.
  str = tree.DumpModel(fmap, true, R"(dot:{"graph_attrs": {"bgcolor": "#FFFF00"}})");
  ASSERT_NE(str.find(R"(graph [ bgcolor="#FFFF00" ])"), std::string::npos);
}
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -10,7 +10,7 @@ try:
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
from matplotlib.axes import Axes
|
||||
from graphviz import Digraph
|
||||
from graphviz import Source
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -57,7 +57,7 @@ class TestPlotting(unittest.TestCase):
|
||||
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
|
||||
|
||||
g = xgb.to_graphviz(bst2, num_trees=0)
|
||||
assert isinstance(g, Digraph)
|
||||
assert isinstance(g, Source)
|
||||
|
||||
ax = xgb.plot_tree(bst2, num_trees=0)
|
||||
assert isinstance(ax, Axes)
|
||||
|
||||
@@ -87,7 +87,6 @@ class TestSHAP(unittest.TestCase):
|
||||
r_exp = r"([0-9]+):\[f([0-9]+)<([0-9\.e-]+)\] yes=([0-9]+),no=([0-9]+).*cover=([0-9e\.]+)"
|
||||
r_exp_leaf = r"([0-9]+):leaf=([0-9\.e-]+),cover=([0-9e\.]+)"
|
||||
for tree in model.get_dump(with_stats=True):
|
||||
|
||||
lines = list(tree.splitlines())
|
||||
trees.append([None for i in range(len(lines))])
|
||||
for line in lines:
|
||||
|
||||
@@ -352,7 +352,7 @@ def test_sklearn_plotting():
|
||||
matplotlib.use('Agg')
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from graphviz import Digraph
|
||||
from graphviz import Source
|
||||
|
||||
ax = xgb.plot_importance(classifier)
|
||||
assert isinstance(ax, Axes)
|
||||
@@ -362,7 +362,7 @@ def test_sklearn_plotting():
|
||||
assert len(ax.patches) == 4
|
||||
|
||||
g = xgb.to_graphviz(classifier, num_trees=0)
|
||||
assert isinstance(g, Digraph)
|
||||
assert isinstance(g, Source)
|
||||
|
||||
ax = xgb.plot_tree(classifier, num_trees=0)
|
||||
assert isinstance(ax, Axes)
|
||||
@@ -641,7 +641,8 @@ def test_XGBClassifier_resume():
|
||||
|
||||
X, Y = load_breast_cancer(return_X_y=True)
|
||||
|
||||
model1 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
|
||||
model1 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||
model1.fit(X, Y)
|
||||
|
||||
pred1 = model1.predict(X)
|
||||
@@ -649,7 +650,8 @@ def test_XGBClassifier_resume():
|
||||
|
||||
# file name of stored xgb model
|
||||
model1.save_model(model1_path)
|
||||
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
|
||||
model2 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||
model2.fit(X, Y, xgb_model=model1_path)
|
||||
|
||||
pred2 = model2.predict(X)
|
||||
@@ -660,7 +662,8 @@ def test_XGBClassifier_resume():
|
||||
|
||||
# file name of 'Booster' instance Xgb model
|
||||
model1.get_booster().save_model(model1_booster_path)
|
||||
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
|
||||
model2 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||
model2.fit(X, Y, xgb_model=model1_booster_path)
|
||||
|
||||
pred2 = model2.predict(X)
|
||||
|
||||
Reference in New Issue
Block a user