Implement tree model dump with code generator. (#4602)

* Implement tree model dump with a code generator.

* Split up generators.
* Implement graphviz generator.
* Use pattern matching.

* [Breaking] Return a Source in `to_graphviz` instead of Digraph in Python package.


Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2019-06-26 15:20:44 +08:00
committed by GitHub
parent fe2de6f415
commit 8bdf15120a
11 changed files with 802 additions and 264 deletions

View File

@@ -1419,7 +1419,7 @@ class Booster(object):
with_stats : bool, optional
Controls whether the split statistics are output.
dump_format : string, optional
Format of model dump. Can be 'text' or 'json'.
Format of model dump. Can be 'text', 'json' or 'dot'.
"""
length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)()

View File

@@ -1,10 +1,7 @@
# coding: utf-8
# pylint: disable=too-many-locals, too-many-arguments, invalid-name,
# pylint: disable=too-many-branches
# coding: utf-8
"""Plotting Library."""
from __future__ import absolute_import
import re
from io import BytesIO
import numpy as np
from .core import Booster
@@ -61,7 +58,8 @@ def plot_importance(booster, ax=None, height=0.2,
raise ImportError('You must install matplotlib to plot importance')
if isinstance(booster, XGBModel):
importance = booster.get_booster().get_score(importance_type=importance_type)
importance = booster.get_booster().get_score(
importance_type=importance_type)
elif isinstance(booster, Booster):
importance = booster.get_score(importance_type=importance_type)
elif isinstance(booster, dict):
@@ -117,56 +115,11 @@ def plot_importance(booster, ax=None, height=0.2,
return ax
# Patterns for the whitespace-separated tokens of a text tree dump.
# All of them are applied with ``.match``, i.e. anchored at the token start.

# Split node token ``<id>:[<condition>]``: group 1 is the numeric node id,
# group 2 is everything between the square brackets (used as the label).
_NODEPAT = re.compile(r'(\d+):\[(.+)\]')
# Leaf token ``<id>:leaf=...``: group 1 is the numeric node id, group 2 is
# the ``leaf=...`` text (used as the label).
_LEAFPAT = re.compile(r'(\d+):(leaf=.+)')
# Edge token that names three target ids: yes, no and missing (in that order).
_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
# Fallback edge token with only yes/no target ids. Because matching is
# prefix-anchored it also matches the start of a token that carries
# ``,missing=...``, so it must be tried after _EDGEPAT.
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')
def _parse_node(graph, text, condition_node_params, leaf_node_params):
    """Add the node described by one dump token to ``graph``.

    Parameters
    ----------
    graph : object
        Receives a ``graph.node(id, label=..., **params)`` call for the
        parsed node.
    text : str
        A single token of the dumped tree.
    condition_node_params : dict
        Extra keyword arguments used when the token matches ``_NODEPAT``.
    leaf_node_params : dict
        Extra keyword arguments used when the token matches ``_LEAFPAT``.

    Returns
    -------
    node : str
        The node id captured from the token.

    Raises
    ------
    ValueError
        If the token matches neither pattern.
    """
    # The split-node pattern is tried first, then the leaf pattern.
    candidates = (
        (_NODEPAT, condition_node_params),
        (_LEAFPAT, leaf_node_params),
    )
    for pattern, node_params in candidates:
        found = pattern.match(text)
        if found is None:
            continue
        node_id = found.group(1)
        graph.node(node_id, label=found.group(2), **node_params)
        return node_id
    raise ValueError('Unable to parse node: {0}'.format(text))
def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
    """Add the two outgoing edges described by one dump token to ``graph``.

    Parameters
    ----------
    graph : object
        Receives two ``graph.edge(node, target, label=..., color=...)``
        calls, the "yes" edge first and the "no" edge second.
    node : str
        Id of the node the edges start from.
    text : str
        A single edge token of the dumped tree.
    yes_color : str, default '#0000FF'
        Color passed for the "yes" edge.
    no_color : str, default '#FF0000'
        Color passed for the "no" edge.

    Raises
    ------
    ValueError
        If the token matches neither edge pattern.
    """
    try:
        full = _EDGEPAT.match(text)
        if full is not None:
            yes, no, missing = full.groups()
            # The "missing" target shares an edge with whichever branch has
            # the same id; that edge's label gets the ", missing" suffix.
            if yes == missing:
                yes_label, no_label = 'yes, missing', 'no'
            else:
                yes_label, no_label = 'yes', 'no, missing'
            graph.edge(node, yes, label=yes_label, color=yes_color)
            graph.edge(node, no, label=no_label, color=no_color)
            return
    except ValueError:
        # A ValueError from the branch above falls through to the
        # two-field pattern below.
        pass

    short = _EDGEPAT2.match(text)
    if short is None:
        raise ValueError('Unable to parse edge: {0}'.format(text))
    yes, no = short.groups()
    graph.edge(node, yes, label='yes', color=yes_color)
    graph.edge(node, no, label='no', color=no_color)
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
yes_color='#0000FF', no_color='#FF0000',
def to_graphviz(booster, fmap='', num_trees=0, rankdir=None,
yes_color=None, no_color=None,
condition_node_params=None, leaf_node_params=None, **kwargs):
"""Convert specified tree to graphviz instance. IPython can automatically plot the
returned graphviz instance. Otherwise, you should call .render() method
"""Convert specified tree to graphviz instance. IPython can automatically plot
the returned graphviz instance. Otherwise, you should call .render() method
of the returned graphviz instance.
Parameters
@@ -184,64 +137,77 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
no_color : str, default '#FF0000'
Edge color when doesn't meet the node condition.
condition_node_params : dict (optional)
condition node configuration,
{'shape':'box',
'style':'filled,rounded',
'fillcolor':'#78bceb'}
Condition node configuration for graphviz. Example:
.. code-block:: python
{'shape': 'box',
'style': 'filled,rounded',
'fillcolor': '#78bceb'}
leaf_node_params : dict (optional)
leaf node configuration
{'shape':'box',
'style':'filled',
'fillcolor':'#e48038'}
Leaf node configuration for graphviz. Example:
kwargs :
Other keywords passed to graphviz graph_attr
.. code-block:: python
{'shape': 'box',
'style': 'filled',
'fillcolor': '#e48038'}
kwargs : Other keywords passed to graphviz graph_attr, E.g.:
``graph [ {key} = {value} ]``
Returns
-------
ax : matplotlib Axes
graph: graphviz.Source
"""
if condition_node_params is None:
condition_node_params = {}
if leaf_node_params is None:
leaf_node_params = {}
try:
from graphviz import Digraph
from graphviz import Source
except ImportError:
raise ImportError('You must install graphviz to plot tree')
if not isinstance(booster, (Booster, XGBModel)):
raise ValueError('booster must be Booster or XGBModel instance')
if isinstance(booster, XGBModel):
booster = booster.get_booster()
tree = booster.get_dump(fmap=fmap)[num_trees]
tree = tree.split()
# squash everything back into kwargs again for compatibility
parameters = 'dot'
extra = {}
for key, value in kwargs.items():
extra[key] = value
kwargs = kwargs.copy()
kwargs.update({'rankdir': rankdir})
graph = Digraph(graph_attr=kwargs)
for i, text in enumerate(tree):
if text[0].isdigit():
node = _parse_node(
graph, text, condition_node_params=condition_node_params,
leaf_node_params=leaf_node_params)
if rankdir is not None:
kwargs['graph_attrs'] = {}
kwargs['graph_attrs']['rankdir'] = rankdir
for key, value in extra.items():
if 'graph_attrs' in kwargs.keys():
kwargs['graph_attrs'][key] = value
else:
if i == 0:
# 1st string must be node
raise ValueError('Unable to parse given string as tree')
_parse_edge(graph, node, text, yes_color=yes_color,
no_color=no_color)
kwargs['graph_attrs'] = {}
del kwargs[key]
return graph
if yes_color is not None or no_color is not None:
kwargs['edge'] = {}
if yes_color is not None:
kwargs['edge']['yes_color'] = yes_color
if no_color is not None:
kwargs['edge']['no_color'] = no_color
if condition_node_params is not None:
kwargs['condition_node_params'] = condition_node_params
if leaf_node_params is not None:
kwargs['leaf_node_params'] = leaf_node_params
if kwargs:
parameters += ':'
parameters += str(kwargs)
tree = booster.get_dump(
fmap=fmap,
dump_format=parameters)[num_trees]
g = Source(tree)
return g
def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
def plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs):
"""Plot specified tree.
Parameters
@@ -252,7 +218,7 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
The name of feature map file
num_trees : int, default 0
Specify the ordinal number of target tree
rankdir : str, default "UT"
rankdir : str, default "TB"
Passed to graphviz via graph_attr
ax : matplotlib Axes, default None
Target axes instance. If None, new figure and axes will be created.
@@ -264,18 +230,17 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
ax : matplotlib Axes
"""
try:
import matplotlib.pyplot as plt
import matplotlib.image as image
from matplotlib import pyplot as plt
from matplotlib import image
except ImportError:
raise ImportError('You must install matplotlib to plot tree')
if ax is None:
_, ax = plt.subplots(1, 1)
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees,
rankdir=rankdir, **kwargs)
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir,
**kwargs)
s = BytesIO()
s.write(g.pipe(format='png'))