ENH: Add visualization to python package
This commit is contained in:
parent
a7202ee804
commit
d24b36adf9
@ -40,6 +40,7 @@ on going at master
|
||||
* Fix List
|
||||
- Fixed possible problem of poisson regression for R.
|
||||
* Python module now throws an exception instead of crashing the terminal when a parameter error happens.
|
||||
* Python module now has importance plot and tree plot functions.
|
||||
* Java api is ready for use
|
||||
* Added more test cases and continuous integration to make each build more robust
|
||||
* Improvements in sklearn compatible module
|
||||
@ -44,3 +44,5 @@ List of Contributors
|
||||
* [Jamie Hall](https://github.com/nerdcha)
|
||||
- Jamie is the initial creator of xgboost sklearn module.
|
||||
* [Yen-Ying Lee](https://github.com/white1033)
|
||||
* [Masaaki Horikoshi](https://github.com/sinhrks)
|
||||
- Masaaki is the initial creator of xgboost python plotting module.
|
||||
@ -35,3 +35,13 @@ Scikit-Learn API
|
||||
.. autoclass:: xgboost.XGBClassifier
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
Plotting API
|
||||
------------
|
||||
.. automodule:: xgboost.plotting
|
||||
|
||||
.. autofunction:: xgboost.plot_importance
|
||||
|
||||
.. autofunction:: xgboost.plot_tree
|
||||
|
||||
.. autofunction:: xgboost.to_graphviz
|
||||
|
||||
@ -127,3 +127,27 @@ If early stopping is enabled during training, you can predict with the best iter
|
||||
```python
|
||||
ypred = bst.predict(xgmat,ntree_limit=bst.best_iteration)
|
||||
```
|
||||
|
||||
Plotting
|
||||
--------
|
||||
|
||||
You can use the plotting module to plot feature importance and to output a tree.
|
||||
|
||||
To plot importance, use ``plot_importance``. This function requires ``matplotlib`` to be installed.
|
||||
|
||||
```python
|
||||
xgb.plot_importance(bst)
|
||||
```
|
||||
|
||||
To output a tree via ``matplotlib``, use ``plot_tree``, specifying the ordinal number of the target tree.
|
||||
This function requires ``graphviz`` and ``matplotlib``.
|
||||
|
||||
```python
|
||||
xgb.plot_tree(bst, num_trees=2)
|
||||
```
|
||||
|
||||
When you use ``IPython``, you can use the ``to_graphviz`` function, which converts the target tree to a ``graphviz`` instance. The ``graphviz`` instance is automatically rendered in ``IPython``.
|
||||
|
||||
```python
|
||||
xgb.to_graphviz(bst, num_trees=2)
|
||||
```
|
||||
@ -8,9 +8,11 @@ from __future__ import absolute_import
|
||||
from .core import DMatrix, Booster
|
||||
from .training import train, cv
|
||||
from .sklearn import XGBModel, XGBClassifier, XGBRegressor
|
||||
from .plotting import plot_importance, plot_tree, to_graphviz
|
||||
|
||||
__version__ = '0.4'
|
||||
|
||||
__all__ = ['DMatrix', 'Booster',
|
||||
'train', 'cv',
|
||||
'XGBModel', 'XGBClassifier', 'XGBRegressor']
|
||||
'XGBModel', 'XGBClassifier', 'XGBRegressor',
|
||||
'plot_importance', 'plot_tree', 'to_graphviz']
|
||||
|
||||
227
python-package/xgboost/plotting.py
Normal file
227
python-package/xgboost/plotting.py
Normal file
@ -0,0 +1,227 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable=too-many-locals, too-many-arguments, invalid-name,
|
||||
# pylint: disable=too-many-branches
|
||||
"""Plotting Library."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import re
|
||||
import numpy as np
|
||||
from .core import Booster
|
||||
|
||||
try:
|
||||
from StringIO import StringIO
|
||||
except ImportError:
|
||||
from io import StringIO
|
||||
|
||||
|
||||
def plot_importance(booster, ax=None, height=0.2,
                    xlim=None, title='Feature importance',
                    xlabel='F score', ylabel='Features',
                    grid=True, **kwargs):
    """Plot importance based on fitted trees.

    Features are drawn as horizontal bars sorted by ascending score, and
    each bar is annotated with its score.

    Parameters
    ----------
    booster : Booster or dict
        Booster instance, or dict taken by Booster.get_fscore()
    ax : matplotlib Axes, default None
        Target axes instance. If None, new figure and axes will be created.
    height : float, default 0.2
        Bar height, passed to ax.barh()
    xlim : tuple, default None
        Tuple of 2 elements passed to ax.set_xlim(). If None, the x axis
        spans from 0 to 1.1 times the largest score.
    title : str, default "Feature importance"
        Axes title. To disable, pass None.
    xlabel : str, default "F score"
        X axis title label. To disable, pass None.
    ylabel : str, default "Features"
        Y axis title label. To disable, pass None.
    grid : bool, default True
        Passed to ax.grid()
    kwargs :
        Other keywords passed to ax.barh()

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib cannot be imported.
    ValueError
        If ``booster`` is neither a Booster nor a dict, if the importance
        mapping is empty, or if ``xlim`` is not a tuple of 2 elements.
    """
    # matplotlib is an optional dependency, so it is imported lazily.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError('You must install matplotlib to plot importance')

    if isinstance(booster, Booster):
        importance = booster.get_fscore()
    elif isinstance(booster, dict):
        importance = booster
    else:
        raise ValueError('tree must be Booster or dict instance')

    if len(importance) == 0:
        raise ValueError('Booster.get_fscore() results in empty')

    # Sort ascending by score; barh draws the first entry at the bottom.
    tuples = sorted(importance.items(), key=lambda item: item[1])
    labels, values = zip(*tuples)

    if ax is None:
        _, ax = plt.subplots(1, 1)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    # Write each score just to the right of the end of its bar.
    for x, y in zip(values, ylocs):
        ax.text(x + 1, y, x, va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        # The original check called len(xlim, 2), which raises TypeError
        # for every tuple; the intended test is on the tuple's length.
        if not isinstance(xlim, tuple) or len(xlim) != 2:
            raise ValueError('xlim must be a tuple of 2 elements')
    else:
        # Leave 10% headroom so the score labels are not clipped.
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax
|
||||
|
||||
|
||||
_NODEPAT = re.compile(r'(\d+):\[(.+)\]')
|
||||
_LEAFPAT = re.compile(r'(\d+):(leaf=.+)')
|
||||
_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
|
||||
|
||||
|
||||
def _parse_node(graph, text):
    """Add the node described by ``text`` to ``graph`` and return its id.

    A split node is added with a circle shape and a leaf with a box shape;
    the split pattern is tried first. Raises ValueError when ``text``
    matches neither pattern.
    """
    for pattern, shape in ((_NODEPAT, 'circle'), (_LEAFPAT, 'box')):
        found = pattern.match(text)
        if found is not None:
            # Both patterns capture exactly (node id, label).
            node_id, label = found.groups()
            graph.node(node_id, label=label, shape=shape)
            return node_id
    raise ValueError('Unable to parse node: {0}'.format(text))
|
||||
|
||||
|
||||
def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
    """Add the "yes" and "no" edges leaving ``node`` to ``graph``.

    ``text`` must match the edge pattern (yes/no/missing target ids). The
    branch whose target equals the "missing" target is labelled with
    ", missing" as well. Raises ValueError when ``text`` does not match.
    """
    found = _EDGEPAT.match(text)
    if found is None:
        raise ValueError('Unable to parse edge: {0}'.format(text))

    yes, no, missing = found.groups()
    if yes == missing:
        yes_label, no_label = 'yes, missing', 'no'
    else:
        yes_label, no_label = 'yes', 'no, missing'

    graph.edge(node, yes, label=yes_label, color=yes_color)
    graph.edge(node, no, label=no_label, color=no_color)
|
||||
|
||||
|
||||
def to_graphviz(booster, num_trees=0, rankdir='UT',
                yes_color='#0000FF', no_color='#FF0000', **kwargs):
    """Convert specified tree to graphviz instance.

    IPython can automatically plot the returned graphviz instance.
    Otherwise, you should call the .render() method of the returned
    graphviz instance.

    Parameters
    ----------
    booster : Booster
        Booster instance
    num_trees : int, default 0
        Specify the ordinal number of target tree
    rankdir : str, default "UT"
        Passed to graphviz via graph_attr.
        NOTE(review): Graphviz documents TB, BT, LR and RL as rankdir
        values; confirm how "UT" is treated by the installed version.
    yes_color : str, default '#0000FF'
        Edge color when meets the node condition.
    no_color : str, default '#FF0000'
        Edge color when doesn't meet the node condition.
    kwargs :
        Other keywords passed to graphviz graph_attr

    Returns
    -------
    graph : graphviz Digraph

    Raises
    ------
    ImportError
        If graphviz cannot be imported.
    ValueError
        If ``booster`` is not a Booster instance, or the dumped tree
        cannot be parsed.
    """
    # graphviz is an optional dependency, so it is imported lazily.
    try:
        from graphviz import Digraph
    except ImportError:
        raise ImportError('You must install graphviz to plot tree')

    if not isinstance(booster, Booster):
        raise ValueError('booster must be Booster instance')

    # The dump of one tree is split on whitespace into node/edge tokens.
    tree = booster.get_dump()[num_trees]
    tree = tree.split()

    # An explicit ``rankdir`` argument overrides any value in kwargs.
    kwargs = kwargs.copy()
    kwargs.update({'rankdir': rankdir})
    graph = Digraph(graph_attr=kwargs)

    for i, text in enumerate(tree):
        if text[0].isdigit():
            # Tokens starting with a digit describe a node ("<id>:...").
            node = _parse_node(graph, text)
        else:
            if i == 0:
                # 1st string must be node
                raise ValueError('Unable to parse given string as tree')
            # Any other token is the edge list of the most recent node.
            _parse_edge(graph, node, text, yes_color=yes_color,
                        no_color=no_color)

    return graph
|
||||
|
||||
|
||||
def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
    """Plot specified tree.

    The tree is converted with to_graphviz, rendered to PNG and shown on
    the axes as an image with the axis decorations turned off.

    Parameters
    ----------
    booster : Booster
        Booster instance
    num_trees : int, default 0
        Specify the ordinal number of target tree
    rankdir : str, default "UT"
        Passed to graphviz via graph_attr
    ax : matplotlib Axes, default None
        Target axes instance. If None, new figure and axes will be created.
    kwargs :
        Other keywords passed to to_graphviz

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib cannot be imported.
    """
    # matplotlib is an optional dependency, so it is imported lazily.
    try:
        import matplotlib.pyplot as plt
        import matplotlib.image as image
    except ImportError:
        raise ImportError('You must install matplotlib to plot tree')

    # The rendered PNG is binary data. The module-level StringIO falls back
    # to the text-only io.StringIO on Python 3, which rejects bytes, so a
    # byte buffer is used here instead.
    from io import BytesIO

    if ax is None:
        _, ax = plt.subplots(1, 1)

    g = to_graphviz(booster, num_trees=num_trees, rankdir=rankdir, **kwargs)

    s = BytesIO()
    s.write(g.pipe(format='png'))
    s.seek(0)  # rewind so imread reads from the start of the buffer
    img = image.imread(s)

    ax.imshow(img)
    ax.axis('off')
    return ax
|
||||
@ -7,7 +7,7 @@ fi
|
||||
brew update
|
||||
|
||||
if [ ${TASK} == "python-package" ]; then
|
||||
brew install python git
|
||||
brew install python git graphviz
|
||||
easy_install pip
|
||||
pip install numpy scipy nose
|
||||
fi
|
||||
|
||||
@ -34,6 +34,8 @@ if [ ${TASK} == "R-package" ]; then
|
||||
fi
|
||||
|
||||
if [ ${TASK} == "python-package" ]; then
|
||||
sudo apt-get install graphviz
|
||||
sudo pip install matplotlib graphviz
|
||||
make all CXX=${CXX} || exit -1
|
||||
nosetests tests/python || exit -1
|
||||
fi
|
||||
|
||||
@ -29,3 +29,44 @@ def test_basic():
|
||||
# assert they are the same
|
||||
assert np.sum(np.abs(preds2-preds)) == 0
|
||||
|
||||
def test_plotting():
    """Check the plotting helpers on the booster loaded from 'xgb.model'."""
    model = xgb.Booster(model_file='xgb.model')

    from matplotlib.axes import Axes
    from graphviz import Digraph

    red = (1.0, 0, 0, 1.0)
    blue = (0, 0, 1.0, 1.0)

    def check_axes(axes, title, xlabel, ylabel):
        # Assertions shared by every plot_importance call below.
        assert isinstance(axes, Axes)
        assert axes.get_title() == title
        assert axes.get_xlabel() == xlabel
        assert axes.get_ylabel() == ylabel
        assert len(axes.patches) == 4

    # Default labels.
    axes = xgb.plot_importance(model)
    check_axes(axes, 'Feature importance', 'F score', 'Features')

    # Custom labels and a single bar color.
    axes = xgb.plot_importance(model, color='r',
                               title='t', xlabel='x', ylabel='y')
    check_axes(axes, 't', 'x', 'y')
    for patch in axes.patches:
        assert patch.get_facecolor() == red

    # Labels disabled and one color per bar.
    axes = xgb.plot_importance(model, color=['r', 'r', 'b', 'b'],
                               title=None, xlabel=None, ylabel=None)
    check_axes(axes, '', '', '')
    for patch, expected in zip(axes.patches, (red, red, blue, blue)):
        assert patch.get_facecolor() == expected

    # Tree rendering.
    graph = xgb.to_graphviz(model, num_trees=0)
    assert isinstance(graph, Digraph)
    axes = xgb.plot_tree(model, num_trees=0)
    assert isinstance(axes, Axes)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user