python package tree plotting support fmap (#1856)

* to_graphviz and plot_tree support fmap

* [python-package] add model_plot docstring
This commit is contained in:
Ian 2016-12-13 21:36:17 +08:00 committed by Yuan (Terry) Tang
parent 49bdb5c97f
commit 167864da75

View File

@ -153,7 +153,7 @@ def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
raise ValueError('Unable to parse edge: {0}'.format(text)) raise ValueError('Unable to parse edge: {0}'.format(text))
def to_graphviz(booster, num_trees=0, rankdir='UT', def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
yes_color='#0000FF', no_color='#FF0000', **kwargs): yes_color='#0000FF', no_color='#FF0000', **kwargs):
"""Convert specified tree to graphviz instance. IPython can automatically plot the """Convert specified tree to graphviz instance. IPython can automatically plot the
@ -164,6 +164,8 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
---------- ----------
booster : Booster, XGBModel booster : Booster, XGBModel
Booster or XGBModel instance Booster or XGBModel instance
fmap: str (optional)
The name of feature map file
num_trees : int, default 0 num_trees : int, default 0
Specify the ordinal number of target tree Specify the ordinal number of target tree
rankdir : str, default "UT" rankdir : str, default "UT"
@ -191,7 +193,7 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
if isinstance(booster, XGBModel): if isinstance(booster, XGBModel):
booster = booster.booster() booster = booster.booster()
tree = booster.get_dump()[num_trees] tree = booster.get_dump(fmap=fmap)[num_trees]
tree = tree.split() tree = tree.split()
kwargs = kwargs.copy() kwargs = kwargs.copy()
@ -211,13 +213,15 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
return graph return graph
def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs): def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
"""Plot specified tree. """Plot specified tree.
Parameters Parameters
---------- ----------
booster : Booster, XGBModel booster : Booster, XGBModel
Booster or XGBModel instance Booster or XGBModel instance
fmap: str (optional)
The name of feature map file
num_trees : int, default 0 num_trees : int, default 0
Specify the ordinal number of target tree Specify the ordinal number of target tree
rankdir : str, default "UT" rankdir : str, default "UT"
@ -242,7 +246,7 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
if ax is None: if ax is None:
_, ax = plt.subplots(1, 1) _, ax = plt.subplots(1, 1)
g = to_graphviz(booster, num_trees=num_trees, rankdir=rankdir, **kwargs) g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs)
s = BytesIO() s = BytesIO()
s.write(g.pipe(format='png')) s.write(g.pipe(format='png'))