ENH: Add visualization to python package

This commit is contained in:
sinhrks 2015-08-11 16:40:09 +09:00
parent a7202ee804
commit d24b36adf9
9 changed files with 311 additions and 2 deletions

View File

@ -40,6 +40,7 @@ on going at master
* Fix List
- Fixed possible problem of poisson regression for R.
* Python module now throws an exception instead of crashing the terminal when a parameter error happens.
* Python module now has importance plot and tree plot functions.
* Java api is ready for use
* Added more test cases and continuous integration to make each build more robust
* Improvements in sklearn compatible module

View File

@ -44,3 +44,5 @@ List of Contributors
* [Jamie Hall](https://github.com/nerdcha)
- Jamie is the initial creator of xgboost sklearn module.
* [Yen-Ying Lee](https://github.com/white1033)
* [Masaaki Horikoshi](https://github.com/sinhrks)
- Masaaki is the initial creator of xgboost python plotting module.

View File

@ -35,3 +35,13 @@ Scikit-Learn API
.. autoclass:: xgboost.XGBClassifier
:members:
:show-inheritance:
Plotting API
------------
.. automodule:: xgboost.plotting
.. autofunction:: xgboost.plot_importance
.. autofunction:: xgboost.plot_tree
.. autofunction:: xgboost.to_graphviz

View File

@ -127,3 +127,27 @@ If early stopping is enabled during training, you can predict with the best iter
```python
ypred = bst.predict(xgmat,ntree_limit=bst.best_iteration)
```
Plotting
--------
You can use the plotting module to plot feature importance and to draw individual trees.
To plot importance, use ``plot_importance``. This function requires ``matplotlib`` to be installed.
```python
xgb.plot_importance(bst)
```
To draw a tree via ``matplotlib``, use ``plot_tree``, specifying the ordinal number of the target tree.
This function requires ``graphviz`` and ``matplotlib``.
```python
xgb.plot_tree(bst, num_trees=2)
```
When you use ``IPython``, you can use the ``to_graphviz`` function, which converts the target tree to a ``graphviz`` instance. The ``graphviz`` instance is automatically rendered in ``IPython``.
```python
xgb.to_graphviz(bst, num_trees=2)
```

View File

@ -8,9 +8,11 @@ from __future__ import absolute_import
from .core import DMatrix, Booster
from .training import train, cv
from .sklearn import XGBModel, XGBClassifier, XGBRegressor
from .plotting import plot_importance, plot_tree, to_graphviz
__version__ = '0.4'
# Names exported when the package is star-imported.
# NOTE(review): the two 'XGBModel' lines below look like the removed and the
# added side of a diff hunk; as written the list closes on the first of them.
# Confirm which line the file is meant to keep.
__all__ = ['DMatrix', 'Booster',
'train', 'cv',
'XGBModel', 'XGBClassifier', 'XGBRegressor']
'XGBModel', 'XGBClassifier', 'XGBRegressor',
'plot_importance', 'plot_tree', 'to_graphviz']

View File

@ -0,0 +1,227 @@
# coding: utf-8
# pylint: disable=too-many-locals, too-many-arguments, invalid-name,
# pylint: disable=too-many-branches
"""Plotting Library."""
from __future__ import absolute_import
import re
import numpy as np
from .core import Booster
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
def plot_importance(booster, ax=None, height=0.2,
                    xlim=None, title='Feature importance',
                    xlabel='F score', ylabel='Features',
                    grid=True, **kwargs):
    """Plot importance based on fitted trees.

    Draws one horizontal bar per feature, sorted by score, and writes the
    score next to each bar.

    Parameters
    ----------
    booster : Booster or dict
        Booster instance, or dict taken by Booster.get_fscore()
    ax : matplotlib Axes, default None
        Target axes instance. If None, new figure and axes will be created.
    height : float, default 0.2
        Bar height, passed to ax.barh()
    xlim : tuple, default None
        Tuple of 2 elements passed to ax.set_xlim(). If None, the axis runs
        from 0 to 1.1 times the largest score.
    title : str, default "Feature importance"
        Axes title. To disable, pass None.
    xlabel : str, default "F score"
        X axis title label. To disable, pass None.
    ylabel : str, default "Features"
        Y axis title label. To disable, pass None.
    grid : bool, default True
        Passed to ax.grid().
    kwargs :
        Other keywords passed to ax.barh()

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib cannot be imported.
    ValueError
        If ``booster`` is neither a Booster nor a dict, if the importance
        mapping is empty, or if ``xlim`` is given but is not a tuple of
        2 elements.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError('You must install matplotlib to plot importance')

    if isinstance(booster, Booster):
        importance = booster.get_fscore()
    elif isinstance(booster, dict):
        importance = booster
    else:
        raise ValueError('booster must be Booster or dict instance')

    if not importance:
        raise ValueError('Booster.get_fscore() results in empty')

    # Validate before drawing so a bad xlim does not leave a half-built plot.
    if xlim is not None and (not isinstance(xlim, tuple) or len(xlim) != 2):
        raise ValueError('xlim must be a tuple of 2 elements')

    # Ascending sort: the largest score gets the highest y position.
    ranked = sorted(importance.items(), key=lambda item: item[1])
    labels, values = zip(*ranked)

    if ax is None:
        _, ax = plt.subplots(1, 1)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    # Write each score just to the right of the end of its bar.
    for x, y in zip(values, ylocs):
        ax.text(x + 1, y, x, va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is None:
        # Leave 10% headroom beyond the longest bar.
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax
# Patterns for the whitespace-separated tokens of a dumped tree, as consumed
# by _parse_node and _parse_edge.
# Bracketed node: "<digits>:[<text>]" -> (node id, bracketed text).
_NODEPAT = re.compile(r'(\d+):\[(.+)\]')
# Leaf node: "<digits>:leaf=<text>" -> (node id, "leaf=<text>").
_LEAFPAT = re.compile(r'(\d+):(leaf=.+)')
# Edge list: "yes=<digits>,no=<digits>,missing=<digits>" -> three target ids.
_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
def _parse_node(graph, text):
    """Add the node described by ``text`` to ``graph`` and return its id.

    Bracketed nodes are drawn as circles and leaves as boxes. Raises
    ValueError when ``text`` matches neither node pattern.
    """
    for pattern, shape in ((_NODEPAT, 'circle'), (_LEAFPAT, 'box')):
        found = pattern.match(text)
        if found is not None:
            node_id, label = found.groups()
            graph.node(node_id, label=label, shape=shape)
            return node_id
    raise ValueError('Unable to parse node: {0}'.format(text))
def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
    """Add the "yes" and "no" edges leaving ``node`` to ``graph``.

    The word "missing" is appended to the label of the "yes" edge when the
    missing target equals the yes target, and to the "no" edge otherwise.
    Raises ValueError when ``text`` does not match the edge pattern.
    """
    found = _EDGEPAT.match(text)
    if found is None:
        raise ValueError('Unable to parse edge: {0}'.format(text))

    yes, no, missing = found.groups()
    missing_on_yes = yes == missing
    yes_label = 'yes, missing' if missing_on_yes else 'yes'
    no_label = 'no' if missing_on_yes else 'no, missing'

    graph.edge(node, yes, label=yes_label, color=yes_color)
    graph.edge(node, no, label=no_label, color=no_color)
def to_graphviz(booster, num_trees=0, rankdir='UT',
                yes_color='#0000FF', no_color='#FF0000', **kwargs):
    """Convert specified tree to graphviz instance.

    IPython can automatically plot the returned graphviz instance.
    Otherwise, you should call the .render() method of the returned
    graphviz instance.

    Parameters
    ----------
    booster : Booster
        Booster instance
    num_trees : int, default 0
        Specify the ordinal number of target tree
    rankdir : str, default "UT"
        Passed to graphviz via graph_attr
    yes_color : str, default '#0000FF'
        Edge color when the node condition is met.
    no_color : str, default '#FF0000'
        Edge color when the node condition is not met.
    kwargs :
        Other keywords passed to graphviz graph_attr

    Returns
    -------
    graph : graphviz.Digraph

    Raises
    ------
    ImportError
        If graphviz cannot be imported.
    ValueError
        If ``booster`` is not a Booster instance, or if the dumped tree
        cannot be parsed.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        raise ImportError('You must install graphviz to plot tree')

    if not isinstance(booster, Booster):
        raise ValueError('booster must be Booster instance')

    # The dump is processed as a flat stream of whitespace-separated tokens.
    tokens = booster.get_dump()[num_trees].split()

    # NOTE(review): 'UT' does not look like one of Graphviz's documented
    # rankdir values (TB, LR, BT, RL) -- confirm the intended default.
    graph_attr = dict(kwargs, rankdir=rankdir)
    graph = Digraph(graph_attr=graph_attr)

    # Tokens starting with a digit declare a node; any other token lists the
    # edges leaving the most recently declared node.
    node = None
    for text in tokens:
        if text[0].isdigit():
            node = _parse_node(graph, text)
        elif node is None:
            # 1st string must be node
            raise ValueError('Unable to parse given string as tree')
        else:
            _parse_edge(graph, node, text, yes_color=yes_color,
                        no_color=no_color)
    return graph
def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
    """Plot specified tree.

    The tree is converted with to_graphviz, rendered to PNG and shown as an
    image on the target axes with the axis decorations turned off.

    Parameters
    ----------
    booster : Booster
        Booster instance
    num_trees : int, default 0
        Specify the ordinal number of target tree
    rankdir : str, default "UT"
        Passed to graphviz via graph_attr
    ax : matplotlib Axes, default None
        Target axes instance. If None, new figure and axes will be created.
    kwargs :
        Other keywords passed to to_graphviz

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib cannot be imported (or graphviz, via to_graphviz).
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.image as image
    except ImportError:
        raise ImportError('You must install matplotlib to plot tree')

    # PNG output is binary, so it must go into a bytes buffer; a text
    # StringIO rejects bytes on Python 3.
    from io import BytesIO

    # Build the graph first so a failure here does not leave an empty figure.
    g = to_graphviz(booster, num_trees=num_trees, rankdir=rankdir, **kwargs)
    png = BytesIO(g.pipe(format='png'))

    if ax is None:
        _, ax = plt.subplots(1, 1)

    img = image.imread(png)
    ax.imshow(img)
    ax.axis('off')
    return ax

View File

@ -7,7 +7,7 @@ fi
brew update

# Quote TASK so the test stays well-formed when the variable is empty or unset.
if [ "${TASK}" == "python-package" ]; then
    # NOTE(review): the next two install lines look like the before/after
    # sides of a diff hunk; the second is a superset of the first -- confirm
    # whether both are meant to stay.
    brew install python git
    brew install python git graphviz
    easy_install pip
    pip install numpy scipy nose
fi

View File

@ -34,6 +34,8 @@ if [ ${TASK} == "R-package" ]; then
fi
# Quote TASK so the test stays well-formed when the variable is empty or unset.
if [ "${TASK}" == "python-package" ]; then
    # -y: answer the confirmation prompt automatically; nobody can reply on CI.
    sudo apt-get install -y graphviz
    sudo pip install matplotlib graphviz
    # Quote CXX so a compiler command containing spaces reaches make intact.
    make all CXX="${CXX}" || exit -1
    nosetests tests/python || exit -1
fi

View File

@ -29,3 +29,44 @@ def test_basic():
# assert they are the same
assert np.sum(np.abs(preds2-preds)) == 0
def test_plotting():
    """Check plot_importance, to_graphviz and plot_tree on the saved model."""
    bst2 = xgb.Booster(model_file='xgb.model')
    # plotting
    from matplotlib.axes import Axes
    from graphviz import Digraph

    red = (1.0, 0, 0, 1.0)
    blue = (0, 0, 1.0, 1.0)

    def check_importance_axes(axes, title, xlabel, ylabel):
        """Assert the axes type, its three labels and the bar count."""
        assert isinstance(axes, Axes)
        assert axes.get_title() == title
        assert axes.get_xlabel() == xlabel
        assert axes.get_ylabel() == ylabel
        assert len(axes.patches) == 4

    # Default labels.
    ax = xgb.plot_importance(bst2)
    check_importance_axes(ax, 'Feature importance', 'F score', 'Features')

    # Custom labels and a single bar color.
    ax = xgb.plot_importance(bst2, color='r',
                             title='t', xlabel='x', ylabel='y')
    check_importance_axes(ax, 't', 'x', 'y')
    for p in ax.patches:
        assert p.get_facecolor() == red

    # Labels disabled and one color per bar.
    ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
                             title=None, xlabel=None, ylabel=None)
    check_importance_axes(ax, '', '', '')
    assert ax.patches[0].get_facecolor() == red
    assert ax.patches[1].get_facecolor() == red
    assert ax.patches[2].get_facecolor() == blue
    assert ax.patches[3].get_facecolor() == blue

    g = xgb.to_graphviz(bst2, num_trees=0)
    assert isinstance(g, Digraph)

    ax = xgb.plot_tree(bst2, num_trees=0)
    assert isinstance(ax, Axes)