Implement tree model dump with code generator. (#4602)
* Implement tree model dump with a code generator. * Split up generators. * Implement graphviz generator. * Use pattern matching. * [Breaking] Return a Source in `to_graphviz` instead of Digraph in Python package. Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1419,7 +1419,7 @@ class Booster(object):
|
||||
with_stats : bool, optional
|
||||
Controls whether the split statistics are output.
|
||||
dump_format : string, optional
|
||||
Format of model dump. Can be 'text' or 'json'.
|
||||
Format of model dump. Can be 'text', 'json' or 'dot'.
|
||||
"""
|
||||
length = c_bst_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable=too-many-locals, too-many-arguments, invalid-name,
|
||||
# pylint: disable=too-many-branches
|
||||
# coding: utf-8
|
||||
"""Plotting Library."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import re
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
from .core import Booster
|
||||
@@ -61,7 +58,8 @@ def plot_importance(booster, ax=None, height=0.2,
|
||||
raise ImportError('You must install matplotlib to plot importance')
|
||||
|
||||
if isinstance(booster, XGBModel):
|
||||
importance = booster.get_booster().get_score(importance_type=importance_type)
|
||||
importance = booster.get_booster().get_score(
|
||||
importance_type=importance_type)
|
||||
elif isinstance(booster, Booster):
|
||||
importance = booster.get_score(importance_type=importance_type)
|
||||
elif isinstance(booster, dict):
|
||||
@@ -117,56 +115,11 @@ def plot_importance(booster, ax=None, height=0.2,
|
||||
return ax
|
||||
|
||||
|
||||
_NODEPAT = re.compile(r'(\d+):\[(.+)\]')
|
||||
_LEAFPAT = re.compile(r'(\d+):(leaf=.+)')
|
||||
_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
|
||||
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')
|
||||
|
||||
|
||||
def _parse_node(graph, text, condition_node_params, leaf_node_params):
|
||||
"""parse dumped node"""
|
||||
match = _NODEPAT.match(text)
|
||||
if match is not None:
|
||||
node = match.group(1)
|
||||
graph.node(node, label=match.group(2), **condition_node_params)
|
||||
return node
|
||||
match = _LEAFPAT.match(text)
|
||||
if match is not None:
|
||||
node = match.group(1)
|
||||
graph.node(node, label=match.group(2), **leaf_node_params)
|
||||
return node
|
||||
raise ValueError('Unable to parse node: {0}'.format(text))
|
||||
|
||||
|
||||
def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
|
||||
"""parse dumped edge"""
|
||||
try:
|
||||
match = _EDGEPAT.match(text)
|
||||
if match is not None:
|
||||
yes, no, missing = match.groups()
|
||||
if yes == missing:
|
||||
graph.edge(node, yes, label='yes, missing', color=yes_color)
|
||||
graph.edge(node, no, label='no', color=no_color)
|
||||
else:
|
||||
graph.edge(node, yes, label='yes', color=yes_color)
|
||||
graph.edge(node, no, label='no, missing', color=no_color)
|
||||
return
|
||||
except ValueError:
|
||||
pass
|
||||
match = _EDGEPAT2.match(text)
|
||||
if match is not None:
|
||||
yes, no = match.groups()
|
||||
graph.edge(node, yes, label='yes', color=yes_color)
|
||||
graph.edge(node, no, label='no', color=no_color)
|
||||
return
|
||||
raise ValueError('Unable to parse edge: {0}'.format(text))
|
||||
|
||||
|
||||
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
||||
yes_color='#0000FF', no_color='#FF0000',
|
||||
def to_graphviz(booster, fmap='', num_trees=0, rankdir=None,
|
||||
yes_color=None, no_color=None,
|
||||
condition_node_params=None, leaf_node_params=None, **kwargs):
|
||||
"""Convert specified tree to graphviz instance. IPython can automatically plot the
|
||||
returned graphiz instance. Otherwise, you should call .render() method
|
||||
"""Convert specified tree to graphviz instance. IPython can automatically plot
|
||||
the returned graphiz instance. Otherwise, you should call .render() method
|
||||
of the returned graphiz instance.
|
||||
|
||||
Parameters
|
||||
@@ -184,64 +137,77 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
||||
no_color : str, default '#FF0000'
|
||||
Edge color when doesn't meet the node condition.
|
||||
condition_node_params : dict (optional)
|
||||
condition node configuration,
|
||||
{'shape':'box',
|
||||
'style':'filled,rounded',
|
||||
'fillcolor':'#78bceb'}
|
||||
Condition node configuration for for graphviz. Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{'shape': 'box',
|
||||
'style': 'filled,rounded',
|
||||
'fillcolor': '#78bceb'}
|
||||
|
||||
leaf_node_params : dict (optional)
|
||||
leaf node configuration
|
||||
{'shape':'box',
|
||||
'style':'filled',
|
||||
'fillcolor':'#e48038'}
|
||||
Leaf node configuration for graphviz. Example:
|
||||
|
||||
kwargs :
|
||||
Other keywords passed to graphviz graph_attr
|
||||
.. code-block:: python
|
||||
|
||||
{'shape': 'box',
|
||||
'style': 'filled',
|
||||
'fillcolor': '#e48038'}
|
||||
|
||||
kwargs : Other keywords passed to graphviz graph_attr, E.g.:
|
||||
``graph [ {key} = {value} ]``
|
||||
|
||||
Returns
|
||||
-------
|
||||
ax : matplotlib Axes
|
||||
graph: graphviz.Source
|
||||
|
||||
"""
|
||||
|
||||
if condition_node_params is None:
|
||||
condition_node_params = {}
|
||||
if leaf_node_params is None:
|
||||
leaf_node_params = {}
|
||||
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
from graphviz import Source
|
||||
except ImportError:
|
||||
raise ImportError('You must install graphviz to plot tree')
|
||||
|
||||
if not isinstance(booster, (Booster, XGBModel)):
|
||||
raise ValueError('booster must be Booster or XGBModel instance')
|
||||
|
||||
if isinstance(booster, XGBModel):
|
||||
booster = booster.get_booster()
|
||||
|
||||
tree = booster.get_dump(fmap=fmap)[num_trees]
|
||||
tree = tree.split()
|
||||
# squash everything back into kwargs again for compatibility
|
||||
parameters = 'dot'
|
||||
extra = {}
|
||||
for key, value in kwargs.items():
|
||||
extra[key] = value
|
||||
|
||||
kwargs = kwargs.copy()
|
||||
kwargs.update({'rankdir': rankdir})
|
||||
graph = Digraph(graph_attr=kwargs)
|
||||
|
||||
for i, text in enumerate(tree):
|
||||
if text[0].isdigit():
|
||||
node = _parse_node(
|
||||
graph, text, condition_node_params=condition_node_params,
|
||||
leaf_node_params=leaf_node_params)
|
||||
if rankdir is not None:
|
||||
kwargs['graph_attrs'] = {}
|
||||
kwargs['graph_attrs']['rankdir'] = rankdir
|
||||
for key, value in extra.items():
|
||||
if 'graph_attrs' in kwargs.keys():
|
||||
kwargs['graph_attrs'][key] = value
|
||||
else:
|
||||
if i == 0:
|
||||
# 1st string must be node
|
||||
raise ValueError('Unable to parse given string as tree')
|
||||
_parse_edge(graph, node, text, yes_color=yes_color,
|
||||
no_color=no_color)
|
||||
kwargs['graph_attrs'] = {}
|
||||
del kwargs[key]
|
||||
|
||||
return graph
|
||||
if yes_color is not None or no_color is not None:
|
||||
kwargs['edge'] = {}
|
||||
if yes_color is not None:
|
||||
kwargs['edge']['yes_color'] = yes_color
|
||||
if no_color is not None:
|
||||
kwargs['edge']['no_color'] = no_color
|
||||
|
||||
if condition_node_params is not None:
|
||||
kwargs['condition_node_params'] = condition_node_params
|
||||
if leaf_node_params is not None:
|
||||
kwargs['leaf_node_params'] = leaf_node_params
|
||||
|
||||
if kwargs:
|
||||
parameters += ':'
|
||||
parameters += str(kwargs)
|
||||
tree = booster.get_dump(
|
||||
fmap=fmap,
|
||||
dump_format=parameters)[num_trees]
|
||||
g = Source(tree)
|
||||
return g
|
||||
|
||||
|
||||
def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
|
||||
def plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs):
|
||||
"""Plot specified tree.
|
||||
|
||||
Parameters
|
||||
@@ -252,7 +218,7 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
|
||||
The name of feature map file
|
||||
num_trees : int, default 0
|
||||
Specify the ordinal number of target tree
|
||||
rankdir : str, default "UT"
|
||||
rankdir : str, default "TB"
|
||||
Passed to graphiz via graph_attr
|
||||
ax : matplotlib Axes, default None
|
||||
Target axes instance. If None, new figure and axes will be created.
|
||||
@@ -264,18 +230,17 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
|
||||
ax : matplotlib Axes
|
||||
|
||||
"""
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.image as image
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib import image
|
||||
except ImportError:
|
||||
raise ImportError('You must install matplotlib to plot tree')
|
||||
|
||||
if ax is None:
|
||||
_, ax = plt.subplots(1, 1)
|
||||
|
||||
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees,
|
||||
rankdir=rankdir, **kwargs)
|
||||
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir,
|
||||
**kwargs)
|
||||
|
||||
s = BytesIO()
|
||||
s.write(g.pipe(format='png'))
|
||||
|
||||
Reference in New Issue
Block a user