From 0cd326c1bc369f1ed395f5fa59cc512166ca6ef0 Mon Sep 17 00:00:00 2001 From: Joey Gao <1783198484@qq.com> Date: Fri, 16 Nov 2018 12:29:37 +0800 Subject: [PATCH] Add parameter to make node type configurable in plot tree (#3859) * add parameters 'conditionNodeParams' and 'leafNodeParams' to function `to_graphviz` enable to configure node type --- python-package/xgboost/plotting.py | 35 +++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py index 99bc31675..8e69dc10c 100644 --- a/python-package/xgboost/plotting.py +++ b/python-package/xgboost/plotting.py @@ -16,7 +16,6 @@ def plot_importance(booster, ax=None, height=0.2, xlabel='F score', ylabel='Features', importance_type='weight', max_num_features=None, grid=True, show_values=True, **kwargs): - """Plot importance based on fitted trees. Parameters @@ -124,17 +123,17 @@ _EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)') _EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)') -def _parse_node(graph, text): +def _parse_node(graph, text, condition_node_params, leaf_node_params): """parse dumped node""" match = _NODEPAT.match(text) if match is not None: node = match.group(1) - graph.node(node, label=match.group(2), shape='circle') + graph.node(node, label=match.group(2), **condition_node_params) return node match = _LEAFPAT.match(text) if match is not None: node = match.group(1) - graph.node(node, label=match.group(2), shape='box') + graph.node(node, label=match.group(2), **leaf_node_params) return node raise ValueError('Unable to parse node: {0}'.format(text)) @@ -164,8 +163,8 @@ def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'): def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT', - yes_color='#0000FF', no_color='#FF0000', **kwargs): - + yes_color='#0000FF', no_color='#FF0000', + condition_node_params=None, leaf_node_params=None, **kwargs): """Convert specified tree to graphviz instance. IPython can automatically plot the returned graphiz instance. Otherwise, you should call .render() method of the returned graphiz instance. @@ -184,6 +183,18 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT', Edge color when meets the node condition. no_color : str, default '#FF0000' Edge color when doesn't meet the node condition. + condition_node_params : dict (optional) + condition node configuration, + {'shape':'box', + 'style':'filled,rounded', + 'fillcolor':'#78bceb' + } + leaf_node_params : dict (optional) + leaf node configuration + {'shape':'box', + 'style':'filled', + 'fillcolor':'#e48038' + } kwargs : Other keywords passed to graphviz graph_attr @@ -192,6 +203,11 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT', ax : matplotlib Axes """ + if condition_node_params is None: + condition_node_params = {} + if leaf_node_params is None: + leaf_node_params = {} + try: from graphviz import Digraph except ImportError: @@ -212,7 +228,9 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT', for i, text in enumerate(tree): if text[0].isdigit(): - node = _parse_node(graph, text) + node = _parse_node( + graph, text, condition_node_params=condition_node_params, + leaf_node_params=leaf_node_params) else: if i == 0: # 1st string must be node @@ -256,7 +274,8 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs): if ax is None: _, ax = plt.subplots(1, 1) - g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs) + g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, + rankdir=rankdir, **kwargs) s = BytesIO() s.write(g.pipe(format='png'))