Add parameter to make node type configurable in plot tree (#3859)
* add parameters 'conditionNodeParams' and 'leafNodeParams' to function `to_graphviz` enable to configure node type
This commit is contained in:
parent
3a150742c7
commit
0cd326c1bc
@ -16,7 +16,6 @@ def plot_importance(booster, ax=None, height=0.2,
|
||||
xlabel='F score', ylabel='Features',
|
||||
importance_type='weight', max_num_features=None,
|
||||
grid=True, show_values=True, **kwargs):
|
||||
|
||||
"""Plot importance based on fitted trees.
|
||||
|
||||
Parameters
|
||||
@ -124,17 +123,17 @@ _EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
|
||||
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')
|
||||
|
||||
|
||||
def _parse_node(graph, text):
|
||||
def _parse_node(graph, text, condition_node_params, leaf_node_params):
|
||||
"""parse dumped node"""
|
||||
match = _NODEPAT.match(text)
|
||||
if match is not None:
|
||||
node = match.group(1)
|
||||
graph.node(node, label=match.group(2), shape='circle')
|
||||
graph.node(node, label=match.group(2), **condition_node_params)
|
||||
return node
|
||||
match = _LEAFPAT.match(text)
|
||||
if match is not None:
|
||||
node = match.group(1)
|
||||
graph.node(node, label=match.group(2), shape='box')
|
||||
graph.node(node, label=match.group(2), **leaf_node_params)
|
||||
return node
|
||||
raise ValueError('Unable to parse node: {0}'.format(text))
|
||||
|
||||
@ -164,8 +163,8 @@ def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
|
||||
|
||||
|
||||
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
||||
yes_color='#0000FF', no_color='#FF0000', **kwargs):
|
||||
|
||||
yes_color='#0000FF', no_color='#FF0000',
|
||||
condition_node_params=None, leaf_node_params=None, **kwargs):
|
||||
"""Convert specified tree to graphviz instance. IPython can automatically plot the
|
||||
returned graphiz instance. Otherwise, you should call .render() method
|
||||
of the returned graphiz instance.
|
||||
@ -184,6 +183,18 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
||||
Edge color when meets the node condition.
|
||||
no_color : str, default '#FF0000'
|
||||
Edge color when doesn't meet the node condition.
|
||||
condition_node_params : dict (optional)
|
||||
condition node configuration,
|
||||
{'shape':'box',
|
||||
'style':'filled,rounded',
|
||||
'fillcolor':'#78bceb'
|
||||
}
|
||||
leaf_node_params : dict (optional)
|
||||
leaf node configuration
|
||||
{'shape':'box',
|
||||
'style':'filled',
|
||||
'fillcolor':'#e48038'
|
||||
}
|
||||
kwargs :
|
||||
Other keywords passed to graphviz graph_attr
|
||||
|
||||
@ -192,6 +203,11 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
||||
ax : matplotlib Axes
|
||||
"""
|
||||
|
||||
if condition_node_params is None:
|
||||
condition_node_params = {}
|
||||
if leaf_node_params is None:
|
||||
leaf_node_params = {}
|
||||
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
@ -212,7 +228,9 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
||||
|
||||
for i, text in enumerate(tree):
|
||||
if text[0].isdigit():
|
||||
node = _parse_node(graph, text)
|
||||
node = _parse_node(
|
||||
graph, text, condition_node_params=condition_node_params,
|
||||
leaf_node_params=leaf_node_params)
|
||||
else:
|
||||
if i == 0:
|
||||
# 1st string must be node
|
||||
@ -256,7 +274,8 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
|
||||
if ax is None:
|
||||
_, ax = plt.subplots(1, 1)
|
||||
|
||||
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs)
|
||||
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees,
|
||||
rankdir=rankdir, **kwargs)
|
||||
|
||||
s = BytesIO()
|
||||
s.write(g.pipe(format='png'))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user