Add parameter to make node type configurable in plot tree (#3859)

* add parameters 'conditionNodeParams' and 'leafNodeParams' to function `to_graphviz` enable to configure node type
This commit is contained in:
Joey Gao 2018-11-16 12:29:37 +08:00 committed by Jiaming Yuan
parent 3a150742c7
commit 0cd326c1bc

View File

@ -16,7 +16,6 @@ def plot_importance(booster, ax=None, height=0.2,
xlabel='F score', ylabel='Features',
importance_type='weight', max_num_features=None,
grid=True, show_values=True, **kwargs):
"""Plot importance based on fitted trees.
Parameters
@ -124,17 +123,17 @@ _EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')
def _parse_node(graph, text):
def _parse_node(graph, text, condition_node_params, leaf_node_params):
"""parse dumped node"""
match = _NODEPAT.match(text)
if match is not None:
node = match.group(1)
graph.node(node, label=match.group(2), shape='circle')
graph.node(node, label=match.group(2), **condition_node_params)
return node
match = _LEAFPAT.match(text)
if match is not None:
node = match.group(1)
graph.node(node, label=match.group(2), shape='box')
graph.node(node, label=match.group(2), **leaf_node_params)
return node
raise ValueError('Unable to parse node: {0}'.format(text))
@ -164,8 +163,8 @@ def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
yes_color='#0000FF', no_color='#FF0000', **kwargs):
yes_color='#0000FF', no_color='#FF0000',
condition_node_params=None, leaf_node_params=None, **kwargs):
"""Convert specified tree to graphviz instance. IPython can automatically plot the
returned graphiz instance. Otherwise, you should call .render() method
of the returned graphiz instance.
@ -184,6 +183,18 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
Edge color when meets the node condition.
no_color : str, default '#FF0000'
Edge color when doesn't meet the node condition.
condition_node_params : dict (optional)
condition node configuration,
{'shape':'box',
'style':'filled,rounded',
'fillcolor':'#78bceb'
}
leaf_node_params : dict (optional)
leaf node configuration
{'shape':'box',
'style':'filled',
'fillcolor':'#e48038'
}
kwargs :
Other keywords passed to graphviz graph_attr
@ -192,6 +203,11 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
ax : matplotlib Axes
"""
if condition_node_params is None:
condition_node_params = {}
if leaf_node_params is None:
leaf_node_params = {}
try:
from graphviz import Digraph
except ImportError:
@ -212,7 +228,9 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
for i, text in enumerate(tree):
if text[0].isdigit():
node = _parse_node(graph, text)
node = _parse_node(
graph, text, condition_node_params=condition_node_params,
leaf_node_params=leaf_node_params)
else:
if i == 0:
# 1st string must be node
@ -256,7 +274,8 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
if ax is None:
_, ax = plt.subplots(1, 1)
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs)
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees,
rankdir=rankdir, **kwargs)
s = BytesIO()
s.write(g.pipe(format='png'))