Jiaming Yuan 8bdf15120a
Implement tree model dump with code generator. (#4602)
* Implement tree model dump with a code generator.

* Split up generators.
* Implement graphviz generator.
* Use pattern matching.

* [Breaking] Return a Source in `to_graphviz` instead of Digraph in Python package.


Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
2019-06-26 15:20:44 +08:00

253 lines
7.9 KiB
Python

# pylint: disable=too-many-locals, too-many-arguments, invalid-name,
# pylint: disable=too-many-branches
# coding: utf-8
"""Plotting Library."""
from io import BytesIO
import numpy as np
from .core import Booster
from .sklearn import XGBModel
def plot_importance(booster, ax=None, height=0.2,
                    xlim=None, ylim=None, title='Feature importance',
                    xlabel='F score', ylabel='Features',
                    importance_type='weight', max_num_features=None,
                    grid=True, show_values=True, **kwargs):
    """Draw a horizontal bar chart of feature importance scores.

    Features are sorted by ascending score, so the most important feature
    ends up at the top of the chart.

    Parameters
    ----------
    booster : Booster, XGBModel or dict
        Source of the scores. For a Booster (or the Booster wrapped by an
        XGBModel) the scores come from ``get_score(importance_type=...)``.
        A dict is used as-is: keys become bar labels, values bar lengths.
    ax : matplotlib Axes, default None
        Axes to draw on. A new figure and axes are created when None.
    height : float, default 0.2
        Bar height, forwarded to ``ax.barh()``.
    xlim : tuple, default None
        2-tuple forwarded to ``ax.set_xlim()``. When None, the range
        ``(0, 1.1 * max(values))`` is used.
    ylim : tuple, default None
        2-tuple forwarded to ``ax.set_ylim()``. When None, the range
        ``(-1, number_of_bars)`` is used.
    title : str, default "Feature importance"
        Axes title. Pass None to leave the title untouched.
    xlabel : str, default "F score"
        X axis label. Pass None to leave the label untouched.
    ylabel : str, default "Features"
        Y axis label. Pass None to leave the label untouched.
    importance_type : str, default "weight"
        How the importance is calculated: either "weight", "gain", or "cover"

        * "weight" is the number of times a feature appears in a tree
        * "gain" is the average gain of splits which use the feature
        * "cover" is the average coverage of splits which use the feature
          where coverage is defined as the number of samples affected by
          the split

        Ignored when ``booster`` is a dict.
    max_num_features : int, default None
        Keep only this many top-scoring features. All features are shown
        when None.
    grid : bool, default True
        Forwarded to ``ax.grid()`` to switch the grid on or off.
    show_values : bool, default True
        Write each score next to its bar. Only the literal ``True``
        enables the labels.
    kwargs :
        Other keywords forwarded to ``ax.barh()``.

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If ``booster`` has an unsupported type, if there are no scores to
        plot, or if ``xlim`` / ``ylim`` is given but is not a 2-tuple.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError('You must install matplotlib to plot importance')

    # Resolve the input into a plain {feature: score} mapping.
    if isinstance(booster, XGBModel):
        scores = booster.get_booster().get_score(
            importance_type=importance_type)
    elif isinstance(booster, Booster):
        scores = booster.get_score(importance_type=importance_type)
    elif isinstance(booster, dict):
        scores = booster
    else:
        raise ValueError('tree must be Booster, XGBModel or dict instance')

    if not scores:
        raise ValueError('Booster.get_score() results in empty')

    # Ascending by score; the tail of the list holds the top features.
    ranked = sorted(((name, scores[name]) for name in scores),
                    key=lambda pair: pair[1])
    if max_num_features is not None:
        # pylint: disable=invalid-unary-operand-type
        ranked = ranked[-max_num_features:]
    labels, values = zip(*ranked)

    if ax is None:
        _, ax = plt.subplots(1, 1)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    if show_values is True:
        for score, yloc in zip(values, ylocs):
            # Label sits just to the right of the bar end.
            ax.text(score + 1, yloc, score, va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is None:
        xlim = (0, max(values) * 1.1)
    elif not (isinstance(xlim, tuple) and len(xlim) == 2):
        raise ValueError('xlim must be a tuple of 2 elements')
    ax.set_xlim(xlim)

    if ylim is None:
        ylim = (-1, len(values))
    elif not (isinstance(ylim, tuple) and len(ylim) == 2):
        raise ValueError('ylim must be a tuple of 2 elements')
    ax.set_ylim(ylim)

    # Each caption is optional; None means "do not set it".
    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    ax.grid(grid)
    return ax
def to_graphviz(booster, fmap='', num_trees=0, rankdir=None,
                yes_color=None, no_color=None,
                condition_node_params=None, leaf_node_params=None, **kwargs):
    """Convert specified tree to graphviz instance.

    IPython can automatically plot the returned graphviz instance.
    Otherwise, you should call the .render() method of the returned
    graphviz instance.

    Parameters
    ----------
    booster : Booster, XGBModel
        Booster or XGBModel instance
    fmap : str (optional)
        The name of feature map file
    num_trees : int, default 0
        Specify the ordinal number of target tree
    rankdir : str, default None
        Passed to graphviz as the ``rankdir`` graph attribute. When None,
        no ``rankdir`` is sent and the choice is left to the dump backend.
    yes_color : str, default None
        Edge color when meets the node condition. When None, the value is
        not sent and the choice is left to the dump backend.
    no_color : str, default None
        Edge color when doesn't meet the node condition. When None, the
        value is not sent and the choice is left to the dump backend.
    condition_node_params : dict (optional)
        Condition node configuration for graphviz.  Example:

        .. code-block:: python

            {'shape': 'box',
             'style': 'filled,rounded',
             'fillcolor': '#78bceb'}

    leaf_node_params : dict (optional)
        Leaf node configuration for graphviz. Example:

        .. code-block:: python

            {'shape': 'box',
             'style': 'filled',
             'fillcolor': '#e48038'}

    kwargs : Other keywords passed to graphviz graph_attr, E.g.:
        ``graph [ {key} = {value} ]``

    Returns
    -------
    graph : graphviz.Source

    Raises
    ------
    ImportError
        If the graphviz Python package is not installed.
    """
    try:
        from graphviz import Source
    except ImportError:
        raise ImportError('You must install graphviz to plot tree')

    if isinstance(booster, XGBModel):
        booster = booster.get_booster()

    # Every extra keyword is a graph attribute.  Collect them all in one
    # dict together with rankdir.  (The previous loop dropped the first
    # extra keyword whenever rankdir was None, because the branch that
    # created 'graph_attrs' never stored the value it was visiting.)
    graph_attrs = {}
    if rankdir is not None:
        graph_attrs['rankdir'] = rankdir
    graph_attrs.update(kwargs)

    edge = {}
    if yes_color is not None:
        edge['yes_color'] = yes_color
    if no_color is not None:
        edge['no_color'] = no_color

    # Only options that were actually supplied are forwarded, so that
    # omitted ones do not override anything on the receiving side.
    config = {}
    if graph_attrs:
        config['graph_attrs'] = graph_attrs
    if edge:
        config['edge'] = edge
    if condition_node_params is not None:
        config['condition_node_params'] = condition_node_params
    if leaf_node_params is not None:
        config['leaf_node_params'] = leaf_node_params

    # The configuration travels inside the dump format string as
    # 'dot:<repr of the dict>'; plain 'dot' is used when nothing is set.
    parameters = 'dot'
    if config:
        parameters += ':' + str(config)

    tree = booster.get_dump(
        fmap=fmap,
        dump_format=parameters)[num_trees]
    return Source(tree)
def plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs):
    """Render one tree as an image and show it on a matplotlib Axes.

    The tree is converted with ``to_graphviz``, piped to PNG, and the
    resulting image is displayed with the axis decorations turned off.

    Parameters
    ----------
    booster : Booster, XGBModel
        Booster or XGBModel instance
    fmap : str (optional)
        The name of feature map file
    num_trees : int, default 0
        Specify the ordinal number of target tree
    rankdir : str, default None
        Forwarded unchanged to ``to_graphviz``.
    ax : matplotlib Axes, default None
        Target axes instance. If None, new figure and axes will be created.
    kwargs :
        Other keywords passed to to_graphviz

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    try:
        from matplotlib import pyplot as plt
        from matplotlib import image
    except ImportError:
        raise ImportError('You must install matplotlib to plot tree')

    if ax is None:
        _, ax = plt.subplots(1, 1)

    graph = to_graphviz(booster, fmap=fmap, num_trees=num_trees,
                        rankdir=rankdir, **kwargs)

    # Keep the PNG bytes in memory and hand them to matplotlib as a
    # file-like object positioned at the start.
    png_buffer = BytesIO(graph.pipe(format='png'))
    ax.imshow(image.imread(png_buffer))
    ax.axis('off')
    return ax