python package tree plotting support fmap (#1856)
* to_graphviz and plot_tree support fmap * [python-package] add model_plot docstring
This commit is contained in:
parent
49bdb5c97f
commit
167864da75
@ -153,7 +153,7 @@ def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
|
|||||||
raise ValueError('Unable to parse edge: {0}'.format(text))
|
raise ValueError('Unable to parse edge: {0}'.format(text))
|
||||||
|
|
||||||
|
|
||||||
def to_graphviz(booster, num_trees=0, rankdir='UT',
|
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
||||||
yes_color='#0000FF', no_color='#FF0000', **kwargs):
|
yes_color='#0000FF', no_color='#FF0000', **kwargs):
|
||||||
|
|
||||||
"""Convert specified tree to graphviz instance. IPython can automatically plot the
|
"""Convert specified tree to graphviz instance. IPython can automatically plot the
|
||||||
@ -164,6 +164,8 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
|
|||||||
----------
|
----------
|
||||||
booster : Booster, XGBModel
|
booster : Booster, XGBModel
|
||||||
Booster or XGBModel instance
|
Booster or XGBModel instance
|
||||||
|
fmap: str (optional)
|
||||||
|
The name of feature map file
|
||||||
num_trees : int, default 0
|
num_trees : int, default 0
|
||||||
Specify the ordinal number of target tree
|
Specify the ordinal number of target tree
|
||||||
rankdir : str, default "UT"
|
rankdir : str, default "UT"
|
||||||
@ -191,7 +193,7 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
|
|||||||
if isinstance(booster, XGBModel):
|
if isinstance(booster, XGBModel):
|
||||||
booster = booster.booster()
|
booster = booster.booster()
|
||||||
|
|
||||||
tree = booster.get_dump()[num_trees]
|
tree = booster.get_dump(fmap=fmap)[num_trees]
|
||||||
tree = tree.split()
|
tree = tree.split()
|
||||||
|
|
||||||
kwargs = kwargs.copy()
|
kwargs = kwargs.copy()
|
||||||
@ -211,13 +213,15 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
|
|||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
|
def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
|
||||||
"""Plot specified tree.
|
"""Plot specified tree.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
booster : Booster, XGBModel
|
booster : Booster, XGBModel
|
||||||
Booster or XGBModel instance
|
Booster or XGBModel instance
|
||||||
|
fmap: str (optional)
|
||||||
|
The name of feature map file
|
||||||
num_trees : int, default 0
|
num_trees : int, default 0
|
||||||
Specify the ordinal number of target tree
|
Specify the ordinal number of target tree
|
||||||
rankdir : str, default "UT"
|
rankdir : str, default "UT"
|
||||||
@ -242,7 +246,7 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
|
|||||||
if ax is None:
|
if ax is None:
|
||||||
_, ax = plt.subplots(1, 1)
|
_, ax = plt.subplots(1, 1)
|
||||||
|
|
||||||
g = to_graphviz(booster, num_trees=num_trees, rankdir=rankdir, **kwargs)
|
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs)
|
||||||
|
|
||||||
s = BytesIO()
|
s = BytesIO()
|
||||||
s.write(g.pipe(format='png'))
|
s.write(g.pipe(format='png'))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user