Add parameter to make node type configurable in plot tree (#3859)
* add parameters 'conditionNodeParams' and 'leafNodeParams' to function `to_graphviz` enable to configure node type
This commit is contained in:
parent
3a150742c7
commit
0cd326c1bc
@ -16,7 +16,6 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||||||
xlabel='F score', ylabel='Features',
|
xlabel='F score', ylabel='Features',
|
||||||
importance_type='weight', max_num_features=None,
|
importance_type='weight', max_num_features=None,
|
||||||
grid=True, show_values=True, **kwargs):
|
grid=True, show_values=True, **kwargs):
|
||||||
|
|
||||||
"""Plot importance based on fitted trees.
|
"""Plot importance based on fitted trees.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -124,17 +123,17 @@ _EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
|
|||||||
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')
|
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')
|
||||||
|
|
||||||
|
|
||||||
def _parse_node(graph, text):
|
def _parse_node(graph, text, condition_node_params, leaf_node_params):
|
||||||
"""parse dumped node"""
|
"""parse dumped node"""
|
||||||
match = _NODEPAT.match(text)
|
match = _NODEPAT.match(text)
|
||||||
if match is not None:
|
if match is not None:
|
||||||
node = match.group(1)
|
node = match.group(1)
|
||||||
graph.node(node, label=match.group(2), shape='circle')
|
graph.node(node, label=match.group(2), **condition_node_params)
|
||||||
return node
|
return node
|
||||||
match = _LEAFPAT.match(text)
|
match = _LEAFPAT.match(text)
|
||||||
if match is not None:
|
if match is not None:
|
||||||
node = match.group(1)
|
node = match.group(1)
|
||||||
graph.node(node, label=match.group(2), shape='box')
|
graph.node(node, label=match.group(2), **leaf_node_params)
|
||||||
return node
|
return node
|
||||||
raise ValueError('Unable to parse node: {0}'.format(text))
|
raise ValueError('Unable to parse node: {0}'.format(text))
|
||||||
|
|
||||||
@ -164,8 +163,8 @@ def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
|
|||||||
|
|
||||||
|
|
||||||
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
||||||
yes_color='#0000FF', no_color='#FF0000', **kwargs):
|
yes_color='#0000FF', no_color='#FF0000',
|
||||||
|
condition_node_params=None, leaf_node_params=None, **kwargs):
|
||||||
"""Convert specified tree to graphviz instance. IPython can automatically plot the
|
"""Convert specified tree to graphviz instance. IPython can automatically plot the
|
||||||
returned graphiz instance. Otherwise, you should call .render() method
|
returned graphiz instance. Otherwise, you should call .render() method
|
||||||
of the returned graphiz instance.
|
of the returned graphiz instance.
|
||||||
@ -184,6 +183,18 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
|||||||
Edge color when meets the node condition.
|
Edge color when meets the node condition.
|
||||||
no_color : str, default '#FF0000'
|
no_color : str, default '#FF0000'
|
||||||
Edge color when doesn't meet the node condition.
|
Edge color when doesn't meet the node condition.
|
||||||
|
condition_node_params : dict (optional)
|
||||||
|
condition node configuration,
|
||||||
|
{'shape':'box',
|
||||||
|
'style':'filled,rounded',
|
||||||
|
'fillcolor':'#78bceb'
|
||||||
|
}
|
||||||
|
leaf_node_params : dict (optional)
|
||||||
|
leaf node configuration
|
||||||
|
{'shape':'box',
|
||||||
|
'style':'filled',
|
||||||
|
'fillcolor':'#e48038'
|
||||||
|
}
|
||||||
kwargs :
|
kwargs :
|
||||||
Other keywords passed to graphviz graph_attr
|
Other keywords passed to graphviz graph_attr
|
||||||
|
|
||||||
@ -192,6 +203,11 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
|||||||
ax : matplotlib Axes
|
ax : matplotlib Axes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if condition_node_params is None:
|
||||||
|
condition_node_params = {}
|
||||||
|
if leaf_node_params is None:
|
||||||
|
leaf_node_params = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from graphviz import Digraph
|
from graphviz import Digraph
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -212,7 +228,9 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
|||||||
|
|
||||||
for i, text in enumerate(tree):
|
for i, text in enumerate(tree):
|
||||||
if text[0].isdigit():
|
if text[0].isdigit():
|
||||||
node = _parse_node(graph, text)
|
node = _parse_node(
|
||||||
|
graph, text, condition_node_params=condition_node_params,
|
||||||
|
leaf_node_params=leaf_node_params)
|
||||||
else:
|
else:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
# 1st string must be node
|
# 1st string must be node
|
||||||
@ -256,7 +274,8 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
|
|||||||
if ax is None:
|
if ax is None:
|
||||||
_, ax = plt.subplots(1, 1)
|
_, ax = plt.subplots(1, 1)
|
||||||
|
|
||||||
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs)
|
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees,
|
||||||
|
rankdir=rankdir, **kwargs)
|
||||||
|
|
||||||
s = BytesIO()
|
s = BytesIO()
|
||||||
s.write(g.pipe(format='png'))
|
s.write(g.pipe(format='png'))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user