Add parameter to make node type configurable in plot tree (#3859)

* add parameters 'conditionNodeParams' and 'leafNodeParams' to function `to_graphviz` enable to configure node type
This commit is contained in:
Joey Gao 2018-11-16 12:29:37 +08:00 committed by Jiaming Yuan
parent 3a150742c7
commit 0cd326c1bc

View File

@ -16,7 +16,6 @@ def plot_importance(booster, ax=None, height=0.2,
xlabel='F score', ylabel='Features', xlabel='F score', ylabel='Features',
importance_type='weight', max_num_features=None, importance_type='weight', max_num_features=None,
grid=True, show_values=True, **kwargs): grid=True, show_values=True, **kwargs):
"""Plot importance based on fitted trees. """Plot importance based on fitted trees.
Parameters Parameters
@ -124,17 +123,17 @@ _EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)') _EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')
def _parse_node(graph, text): def _parse_node(graph, text, condition_node_params, leaf_node_params):
"""parse dumped node""" """parse dumped node"""
match = _NODEPAT.match(text) match = _NODEPAT.match(text)
if match is not None: if match is not None:
node = match.group(1) node = match.group(1)
graph.node(node, label=match.group(2), shape='circle') graph.node(node, label=match.group(2), **condition_node_params)
return node return node
match = _LEAFPAT.match(text) match = _LEAFPAT.match(text)
if match is not None: if match is not None:
node = match.group(1) node = match.group(1)
graph.node(node, label=match.group(2), shape='box') graph.node(node, label=match.group(2), **leaf_node_params)
return node return node
raise ValueError('Unable to parse node: {0}'.format(text)) raise ValueError('Unable to parse node: {0}'.format(text))
@ -164,8 +163,8 @@ def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT', def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
yes_color='#0000FF', no_color='#FF0000', **kwargs): yes_color='#0000FF', no_color='#FF0000',
condition_node_params=None, leaf_node_params=None, **kwargs):
"""Convert specified tree to graphviz instance. IPython can automatically plot the """Convert specified tree to graphviz instance. IPython can automatically plot the
returned graphiz instance. Otherwise, you should call .render() method returned graphiz instance. Otherwise, you should call .render() method
of the returned graphiz instance. of the returned graphiz instance.
@ -184,6 +183,18 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
Edge color when meets the node condition. Edge color when meets the node condition.
no_color : str, default '#FF0000' no_color : str, default '#FF0000'
Edge color when doesn't meet the node condition. Edge color when doesn't meet the node condition.
condition_node_params : dict (optional)
condition node configuration,
{'shape':'box',
'style':'filled,rounded',
'fillcolor':'#78bceb'
}
leaf_node_params : dict (optional)
leaf node configuration
{'shape':'box',
'style':'filled',
'fillcolor':'#e48038'
}
kwargs : kwargs :
Other keywords passed to graphviz graph_attr Other keywords passed to graphviz graph_attr
@ -192,6 +203,11 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
ax : matplotlib Axes ax : matplotlib Axes
""" """
if condition_node_params is None:
condition_node_params = {}
if leaf_node_params is None:
leaf_node_params = {}
try: try:
from graphviz import Digraph from graphviz import Digraph
except ImportError: except ImportError:
@ -212,7 +228,9 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
for i, text in enumerate(tree): for i, text in enumerate(tree):
if text[0].isdigit(): if text[0].isdigit():
node = _parse_node(graph, text) node = _parse_node(
graph, text, condition_node_params=condition_node_params,
leaf_node_params=leaf_node_params)
else: else:
if i == 0: if i == 0:
# 1st string must be node # 1st string must be node
@ -256,7 +274,8 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
if ax is None: if ax is None:
_, ax = plt.subplots(1, 1) _, ax = plt.subplots(1, 1)
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs) g = to_graphviz(booster, fmap=fmap, num_trees=num_trees,
rankdir=rankdir, **kwargs)
s = BytesIO() s = BytesIO()
s.write(g.pipe(format='png')) s.write(g.pipe(format='png'))