ENH: Add visualization to python package

This commit is contained in:
sinhrks 2015-08-11 16:40:09 +09:00
parent a7202ee804
commit d24b36adf9
9 changed files with 311 additions and 2 deletions

View File

@ -40,6 +40,7 @@ on going at master
* Fix List * Fix List
- Fixed possible problem of poisson regression for R. - Fixed possible problem of poisson regression for R.
* Python module now throw exception instead of crash terminal when a parameter error happens. * Python module now throw exception instead of crash terminal when a parameter error happens.
* Python module now has importance plot and tree plot functions.
* Java api is ready for use * Java api is ready for use
* Added more test cases and continuous integration to make each build more robust * Added more test cases and continuous integration to make each build more robust
* Improvements in sklearn compatible module * Improvements in sklearn compatible module

View File

@ -44,3 +44,5 @@ List of Contributors
* [Jamie Hall](https://github.com/nerdcha) * [Jamie Hall](https://github.com/nerdcha)
- Jamie is the initial creator of xgboost sklearn modue. - Jamie is the initial creator of xgboost sklearn modue.
* [Yen-Ying Lee](https://github.com/white1033) * [Yen-Ying Lee](https://github.com/white1033)
* [Masaaki Horikoshi](https://github.com/sinhrks)
- Masaaki is the initial creator of xgboost python plotting module.

View File

@ -35,3 +35,13 @@ Scikit-Learn API
.. autoclass:: xgboost.XGBClassifier .. autoclass:: xgboost.XGBClassifier
:members: :members:
:show-inheritance: :show-inheritance:
Plotting API
------------
.. automodule:: xgboost.plotting
.. autofunction:: xgboost.plot_importance
.. autofunction:: xgboost.plot_tree
.. autofunction:: xgboost.to_graphviz

View File

@ -127,3 +127,27 @@ If early stopping is enabled during training, you can predict with the best iter
```python ```python
ypred = bst.predict(xgmat,ntree_limit=bst.best_iteration) ypred = bst.predict(xgmat,ntree_limit=bst.best_iteration)
``` ```
Plotting
--------
You can use plotting module to plot importance and output tree.
To plot importance, use ``plot_importance``. This function requires ``matplotlib`` to be installed.
```python
xgb.plot_importance(bst)
```
To output tree via ``matplotlib``, use ``plot_tree`` specifying ordinal number of the target tree.
This function requires ``graphviz`` and ``matplotlib``.
```python
xgb.plot_tree(bst, num_trees=2)
```
When you use ``IPython``, you can use ``to_graphviz`` function which converts the target tree to ``graphviz`` instance. ``graphviz`` instance is automatically rendered on ``IPython``.
```python
xgb.to_graphviz(bst, num_trees=2)
```

View File

@ -8,9 +8,11 @@ from __future__ import absolute_import
from .core import DMatrix, Booster from .core import DMatrix, Booster
from .training import train, cv from .training import train, cv
from .sklearn import XGBModel, XGBClassifier, XGBRegressor from .sklearn import XGBModel, XGBClassifier, XGBRegressor
from .plotting import plot_importance, plot_tree, to_graphviz
__version__ = '0.4' __version__ = '0.4'
__all__ = ['DMatrix', 'Booster', __all__ = ['DMatrix', 'Booster',
'train', 'cv', 'train', 'cv',
'XGBModel', 'XGBClassifier', 'XGBRegressor'] 'XGBModel', 'XGBClassifier', 'XGBRegressor',
'plot_importance', 'plot_tree', 'to_graphviz']

View File

@ -0,0 +1,227 @@
# coding: utf-8
# pylint: disable=too-many-locals, too-many-arguments, invalid-name,
# pylint: disable=too-many-branches
"""Plotting Library."""
from __future__ import absolute_import
import re
import numpy as np
from .core import Booster
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
def plot_importance(booster, ax=None, height=0.2,
xlim=None, title='Feature importance',
xlabel='F score', ylabel='Features',
grid=True, **kwargs):
"""Plot importance based on fitted trees.
Parameters
----------
booster : Booster or dict
Booster instance, or dict taken by Booster.get_fscore()
ax : matplotlib Axes, default None
Target axes instance. If None, new figure and axes will be created.
height : float, default 0.2
Bar height, passed to ax.barh()
xlim : tuple, default None
Tuple passed to axes.xlim()
title : str, default "Feature importance"
Axes title. To disable, pass None.
xlabel : str, default "F score"
X axis title label. To disable, pass None.
ylabel : str, default "Features"
Y axis title label. To disable, pass None.
kwargs :
Other keywords passed to ax.barh()
Returns
-------
ax : matplotlib Axes
"""
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError('You must install matplotlib to plot importance')
if isinstance(booster, Booster):
importance = booster.get_fscore()
elif isinstance(booster, dict):
importance = booster
else:
raise ValueError('tree must be Booster or dict instance')
if len(importance) == 0:
raise ValueError('Booster.get_fscore() results in empty')
tuples = [(k, importance[k]) for k in importance]
tuples = sorted(tuples, key=lambda x: x[1])
labels, values = zip(*tuples)
if ax is None:
_, ax = plt.subplots(1, 1)
ylocs = np.arange(len(values))
ax.barh(ylocs, values, align='center', height=height, **kwargs)
for x, y in zip(values, ylocs):
ax.text(x + 1, y, x, va='center')
ax.set_yticks(ylocs)
ax.set_yticklabels(labels)
if xlim is not None:
if not isinstance(xlim, tuple) or len(xlim, 2):
raise ValueError('xlim must be a tuple of 2 elements')
else:
xlim = (0, max(values) * 1.1)
ax.set_xlim(xlim)
if title is not None:
ax.set_title(title)
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
ax.grid(grid)
return ax
_NODEPAT = re.compile(r'(\d+):\[(.+)\]')
_LEAFPAT = re.compile(r'(\d+):(leaf=.+)')
_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
def _parse_node(graph, text):
"""parse dumped node"""
match = _NODEPAT.match(text)
if match is not None:
node = match.group(1)
graph.node(node, label=match.group(2), shape='circle')
return node
match = _LEAFPAT.match(text)
if match is not None:
node = match.group(1)
graph.node(node, label=match.group(2), shape='box')
return node
raise ValueError('Unable to parse node: {0}'.format(text))
def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
"""parse dumped edge"""
match = _EDGEPAT.match(text)
if match is not None:
yes, no, missing = match.groups()
if yes == missing:
graph.edge(node, yes, label='yes, missing', color=yes_color)
graph.edge(node, no, label='no', color=no_color)
else:
graph.edge(node, yes, label='yes', color=yes_color)
graph.edge(node, no, label='no, missing', color=no_color)
return
raise ValueError('Unable to parse edge: {0}'.format(text))
def to_graphviz(booster, num_trees=0, rankdir='UT',
yes_color='#0000FF', no_color='#FF0000', **kwargs):
"""Convert specified tree to graphviz instance. IPython can automatically plot the
returned graphiz instance. Otherwise, you shoud call .render() method
of the returned graphiz instance.
Parameters
----------
booster : Booster
Booster instance
num_trees : int, default 0
Specify the ordinal number of target tree
rankdir : str, default "UT"
Passed to graphiz via graph_attr
yes_color : str, default '#0000FF'
Edge color when meets the node condigion.
no_color : str, default '#FF0000'
Edge color when doesn't meet the node condigion.
kwargs :
Other keywords passed to graphviz graph_attr
Returns
-------
ax : matplotlib Axes
"""
try:
from graphviz import Digraph
except ImportError:
raise ImportError('You must install graphviz to plot tree')
if not isinstance(booster, Booster):
raise ValueError('booster must be Booster instance')
tree = booster.get_dump()[num_trees]
tree = tree.split()
kwargs = kwargs.copy()
kwargs.update({'rankdir': rankdir})
graph = Digraph(graph_attr=kwargs)
for i, text in enumerate(tree):
if text[0].isdigit():
node = _parse_node(graph, text)
else:
if i == 0:
# 1st string must be node
raise ValueError('Unable to parse given string as tree')
_parse_edge(graph, node, text, yes_color=yes_color,
no_color=no_color)
return graph
def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
"""Plot specified tree.
Parameters
----------
booster : Booster
Booster instance
num_trees : int, default 0
Specify the ordinal number of target tree
rankdir : str, default "UT"
Passed to graphiz via graph_attr
ax : matplotlib Axes, default None
Target axes instance. If None, new figure and axes will be created.
kwargs :
Other keywords passed to to_graphviz
Returns
-------
ax : matplotlib Axes
"""
try:
import matplotlib.pyplot as plt
import matplotlib.image as image
except ImportError:
raise ImportError('You must install matplotlib to plot tree')
if ax is None:
_, ax = plt.subplots(1, 1)
g = to_graphviz(booster, num_trees=num_trees, rankdir=rankdir, **kwargs)
s = StringIO()
s.write(g.pipe(format='png'))
s.seek(0)
img = image.imread(s)
ax.imshow(img)
ax.axis('off')
return ax

View File

@ -7,7 +7,7 @@ fi
brew update brew update
if [ ${TASK} == "python-package" ]; then if [ ${TASK} == "python-package" ]; then
brew install python git brew install python git graphviz
easy_install pip easy_install pip
pip install numpy scipy nose pip install numpy scipy nose
fi fi

View File

@ -34,6 +34,8 @@ if [ ${TASK} == "R-package" ]; then
fi fi
if [ ${TASK} == "python-package" ]; then if [ ${TASK} == "python-package" ]; then
sudo apt-get install graphviz
sudo pip install matplotlib graphviz
make all CXX=${CXX} || exit -1 make all CXX=${CXX} || exit -1
nosetests tests/python || exit -1 nosetests tests/python || exit -1
fi fi

View File

@ -29,3 +29,44 @@ def test_basic():
# assert they are the same # assert they are the same
assert np.sum(np.abs(preds2-preds)) == 0 assert np.sum(np.abs(preds2-preds)) == 0
def test_plotting():
bst2 = xgb.Booster(model_file='xgb.model')
# plotting
from matplotlib.axes import Axes
from graphviz import Digraph
ax = xgb.plot_importance(bst2)
assert isinstance(ax, Axes)
assert ax.get_title() == 'Feature importance'
assert ax.get_xlabel() == 'F score'
assert ax.get_ylabel() == 'Features'
assert len(ax.patches) == 4
ax = xgb.plot_importance(bst2, color='r',
title='t', xlabel='x', ylabel='y')
assert isinstance(ax, Axes)
assert ax.get_title() == 't'
assert ax.get_xlabel() == 'x'
assert ax.get_ylabel() == 'y'
assert len(ax.patches) == 4
for p in ax.patches:
assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red
ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
title=None, xlabel=None, ylabel=None)
assert isinstance(ax, Axes)
assert ax.get_title() == ''
assert ax.get_xlabel() == ''
assert ax.get_ylabel() == ''
assert len(ax.patches) == 4
assert ax.patches[0].get_facecolor() == (1.0, 0, 0, 1.0) # red
assert ax.patches[1].get_facecolor() == (1.0, 0, 0, 1.0) # red
assert ax.patches[2].get_facecolor() == (0, 0, 1.0, 1.0) # blue
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
g = xgb.to_graphviz(bst2, num_trees=0)
assert isinstance(g, Digraph)
ax = xgb.plot_tree(bst2, num_trees=0)
assert isinstance(ax, Axes)