python package tree plotting support fmap (#1856)
* to_graphviz and plot_tree support fmap * [python-package] add model_plot docstring
This commit is contained in:
parent
49bdb5c97f
commit
167864da75
@ -153,7 +153,7 @@ def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
|
||||
raise ValueError('Unable to parse edge: {0}'.format(text))
|
||||
|
||||
|
||||
def to_graphviz(booster, num_trees=0, rankdir='UT',
|
||||
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
||||
yes_color='#0000FF', no_color='#FF0000', **kwargs):
|
||||
|
||||
"""Convert specified tree to graphviz instance. IPython can automatically plot the
|
||||
@ -164,6 +164,8 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
|
||||
----------
|
||||
booster : Booster, XGBModel
|
||||
Booster or XGBModel instance
|
||||
fmap: str (optional)
|
||||
The name of feature map file
|
||||
num_trees : int, default 0
|
||||
Specify the ordinal number of target tree
|
||||
rankdir : str, default "UT"
|
||||
@ -191,7 +193,7 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
|
||||
if isinstance(booster, XGBModel):
|
||||
booster = booster.booster()
|
||||
|
||||
tree = booster.get_dump()[num_trees]
|
||||
tree = booster.get_dump(fmap=fmap)[num_trees]
|
||||
tree = tree.split()
|
||||
|
||||
kwargs = kwargs.copy()
|
||||
@ -211,13 +213,15 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
|
||||
return graph
|
||||
|
||||
|
||||
def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
|
||||
def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
|
||||
"""Plot specified tree.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
booster : Booster, XGBModel
|
||||
Booster or XGBModel instance
|
||||
fmap: str (optional)
|
||||
The name of feature map file
|
||||
num_trees : int, default 0
|
||||
Specify the ordinal number of target tree
|
||||
rankdir : str, default "UT"
|
||||
@ -242,7 +246,7 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
|
||||
if ax is None:
|
||||
_, ax = plt.subplots(1, 1)
|
||||
|
||||
g = to_graphviz(booster, num_trees=num_trees, rankdir=rankdir, **kwargs)
|
||||
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs)
|
||||
|
||||
s = BytesIO()
|
||||
s.write(g.pipe(format='png'))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user