Implement tree model dump with code generator. (#4602)

* Implement tree model dump with a code generator.

* Split up generators.
* Implement graphviz generator.
* Use pattern matching.

* [Breaking] Return a Source in `to_graphviz` instead of Digraph in Python package.


Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2019-06-26 15:20:44 +08:00
committed by GitHub
parent fe2de6f415
commit 8bdf15120a
11 changed files with 802 additions and 264 deletions

View File

@@ -10,7 +10,7 @@ try:
import matplotlib
matplotlib.use('Agg')
from matplotlib.axes import Axes
from graphviz import Digraph
from graphviz import Source
except ImportError:
pass
@@ -57,7 +57,7 @@ class TestPlotting(unittest.TestCase):
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
g = xgb.to_graphviz(bst2, num_trees=0)
assert isinstance(g, Digraph)
assert isinstance(g, Source)
ax = xgb.plot_tree(bst2, num_trees=0)
assert isinstance(ax, Axes)

View File

@@ -87,7 +87,6 @@ class TestSHAP(unittest.TestCase):
r_exp = r"([0-9]+):\[f([0-9]+)<([0-9\.e-]+)\] yes=([0-9]+),no=([0-9]+).*cover=([0-9e\.]+)"
r_exp_leaf = r"([0-9]+):leaf=([0-9\.e-]+),cover=([0-9e\.]+)"
for tree in model.get_dump(with_stats=True):
lines = list(tree.splitlines())
trees.append([None for i in range(len(lines))])
for line in lines:

View File

@@ -352,7 +352,7 @@ def test_sklearn_plotting():
matplotlib.use('Agg')
from matplotlib.axes import Axes
from graphviz import Digraph
from graphviz import Source
ax = xgb.plot_importance(classifier)
assert isinstance(ax, Axes)
@@ -362,7 +362,7 @@ def test_sklearn_plotting():
assert len(ax.patches) == 4
g = xgb.to_graphviz(classifier, num_trees=0)
assert isinstance(g, Digraph)
assert isinstance(g, Source)
ax = xgb.plot_tree(classifier, num_trees=0)
assert isinstance(ax, Axes)
@@ -641,7 +641,8 @@ def test_XGBClassifier_resume():
X, Y = load_breast_cancer(return_X_y=True)
model1 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
model1 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8)
model1.fit(X, Y)
pred1 = model1.predict(X)
@@ -649,7 +650,8 @@ def test_XGBClassifier_resume():
# file name of stored xgb model
model1.save_model(model1_path)
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
model2 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8)
model2.fit(X, Y, xgb_model=model1_path)
pred2 = model2.predict(X)
@@ -660,7 +662,8 @@ def test_XGBClassifier_resume():
# file name of 'Booster' instance Xgb model
model1.get_booster().save_model(model1_booster_path)
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
model2 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8)
model2.fit(X, Y, xgb_model=model1_booster_path)
pred2 = model2.predict(X)