diff --git a/CHANGES.md b/CHANGES.md index d9c8786c0..4484a321b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -40,6 +40,7 @@ on going at master * Fix List - Fixed possible problem of poisson regression for R. * Python module now throw exception instead of crash terminal when a parameter error happens. +* Python module now has importance plot and tree plot functions. * Java api is ready for use * Added more test cases and continuous integration to make each build more robust * Improvements in sklearn compatible module \ No newline at end of file diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 36ccc9d5d..ad6c01f2f 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -44,3 +44,5 @@ List of Contributors * [Jamie Hall](https://github.com/nerdcha) - Jamie is the initial creator of xgboost sklearn modue. * [Yen-Ying Lee](https://github.com/white1033) +* [Masaaki Horikoshi](https://github.com/sinhrks) + - Masaaki is the initial creator of xgboost python plotting module. \ No newline at end of file diff --git a/doc/python/python_api.rst b/doc/python/python_api.rst index 85249cbc4..1374e4bfc 100644 --- a/doc/python/python_api.rst +++ b/doc/python/python_api.rst @@ -35,3 +35,13 @@ Scikit-Learn API .. autoclass:: xgboost.XGBClassifier :members: :show-inheritance: + +Plotting API +------------ +.. automodule:: xgboost.plotting + +.. autofunction:: xgboost.plot_importance + +.. autofunction:: xgboost.plot_tree + +.. autofunction:: xgboost.to_graphviz diff --git a/doc/python/python_intro.md b/doc/python/python_intro.md index 2b670a053..b46358877 100644 --- a/doc/python/python_intro.md +++ b/doc/python/python_intro.md @@ -127,3 +127,27 @@ If early stopping is enabled during training, you can predict with the best iter ```python ypred = bst.predict(xgmat,ntree_limit=bst.best_iteration) ``` + +Plotting +-------- + +You can use plotting module to plot importance and output tree. + +To plot importance, use ``plot_importance``. This function requires ``matplotlib`` to be installed. + +```python +xgb.plot_importance(bst) +``` + +To output tree via ``matplotlib``, use ``plot_tree`` specifying ordinal number of the target tree. +This function requires ``graphviz`` and ``matplotlib``. + +```python +xgb.plot_tree(bst, num_trees=2) +``` + +When you use ``IPython``, you can use ``to_graphviz`` function which converts the target tree to ``graphviz`` instance. ``graphviz`` instance is automatically rendered on ``IPython``. + +```python +xgb.to_graphviz(bst, num_trees=2) +``` \ No newline at end of file diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index b284c27e0..b251b4501 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -8,9 +8,11 @@ from __future__ import absolute_import from .core import DMatrix, Booster from .training import train, cv from .sklearn import XGBModel, XGBClassifier, XGBRegressor +from .plotting import plot_importance, plot_tree, to_graphviz __version__ = '0.4' __all__ = ['DMatrix', 'Booster', 'train', 'cv', - 'XGBModel', 'XGBClassifier', 'XGBRegressor'] + 'XGBModel', 'XGBClassifier', 'XGBRegressor', + 'plot_importance', 'plot_tree', 'to_graphviz'] diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py new file mode 100644 index 000000000..7c34a11f1 --- /dev/null +++ b/python-package/xgboost/plotting.py @@ -0,0 +1,227 @@ +# coding: utf-8 +# pylint: disable=too-many-locals, too-many-arguments, invalid-name, +# pylint: disable=too-many-branches +"""Plotting Library.""" +from __future__ import absolute_import + +import re +import numpy as np +from .core import Booster + +try: + from StringIO import StringIO +except ImportError: + from io import StringIO + + +def plot_importance(booster, ax=None, height=0.2, + xlim=None, title='Feature importance', + xlabel='F score', ylabel='Features', + grid=True, **kwargs): + + """Plot importance based on fitted trees. + + Parameters + ---------- + booster : Booster or dict + Booster instance, or dict taken by Booster.get_fscore() + ax : matplotlib Axes, default None + Target axes instance. If None, new figure and axes will be created. + height : float, default 0.2 + Bar height, passed to ax.barh() + xlim : tuple, default None + Tuple passed to axes.xlim() + title : str, default "Feature importance" + Axes title. To disable, pass None. + xlabel : str, default "F score" + X axis title label. To disable, pass None. + ylabel : str, default "Features" + Y axis title label. To disable, pass None. + kwargs : + Other keywords passed to ax.barh() + + Returns + ------- + ax : matplotlib Axes + """ + + try: + import matplotlib.pyplot as plt + except ImportError: + raise ImportError('You must install matplotlib to plot importance') + + if isinstance(booster, Booster): + importance = booster.get_fscore() + elif isinstance(booster, dict): + importance = booster + else: + raise ValueError('tree must be Booster or dict instance') + + if len(importance) == 0: + raise ValueError('Booster.get_fscore() results in empty') + + tuples = [(k, importance[k]) for k in importance] + tuples = sorted(tuples, key=lambda x: x[1]) + labels, values = zip(*tuples) + + if ax is None: + _, ax = plt.subplots(1, 1) + + ylocs = np.arange(len(values)) + ax.barh(ylocs, values, align='center', height=height, **kwargs) + + for x, y in zip(values, ylocs): + ax.text(x + 1, y, x, va='center') + + ax.set_yticks(ylocs) + ax.set_yticklabels(labels) + + if xlim is not None: + if not isinstance(xlim, tuple) or len(xlim, 2): + raise ValueError('xlim must be a tuple of 2 elements') + else: + xlim = (0, max(values) * 1.1) + ax.set_xlim(xlim) + + if title is not None: + ax.set_title(title) + if xlabel is not None: + ax.set_xlabel(xlabel) + if ylabel is not None: + ax.set_ylabel(ylabel) + ax.grid(grid) + return ax + + +_NODEPAT = re.compile(r'(\d+):\[(.+)\]') +_LEAFPAT = re.compile(r'(\d+):(leaf=.+)') +_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)') + + +def _parse_node(graph, text): + """parse dumped node""" + match = _NODEPAT.match(text) + if match is not None: + node = match.group(1) + graph.node(node, label=match.group(2), shape='circle') + return node + match = _LEAFPAT.match(text) + if match is not None: + node = match.group(1) + graph.node(node, label=match.group(2), shape='box') + return node + raise ValueError('Unable to parse node: {0}'.format(text)) + + +def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'): + """parse dumped edge""" + match = _EDGEPAT.match(text) + if match is not None: + yes, no, missing = match.groups() + if yes == missing: + graph.edge(node, yes, label='yes, missing', color=yes_color) + graph.edge(node, no, label='no', color=no_color) + else: + graph.edge(node, yes, label='yes', color=yes_color) + graph.edge(node, no, label='no, missing', color=no_color) + return + raise ValueError('Unable to parse edge: {0}'.format(text)) + + +def to_graphviz(booster, num_trees=0, rankdir='UT', + yes_color='#0000FF', no_color='#FF0000', **kwargs): + + """Convert specified tree to graphviz instance. IPython can automatically plot the + returned graphiz instance. Otherwise, you shoud call .render() method + of the returned graphiz instance. + + Parameters + ---------- + booster : Booster + Booster instance + num_trees : int, default 0 + Specify the ordinal number of target tree + rankdir : str, default "UT" + Passed to graphiz via graph_attr + yes_color : str, default '#0000FF' + Edge color when meets the node condigion. + no_color : str, default '#FF0000' + Edge color when doesn't meet the node condigion. + kwargs : + Other keywords passed to graphviz graph_attr + + Returns + ------- + ax : matplotlib Axes + """ + + try: + from graphviz import Digraph + except ImportError: + raise ImportError('You must install graphviz to plot tree') + + if not isinstance(booster, Booster): + raise ValueError('booster must be Booster instance') + + tree = booster.get_dump()[num_trees] + tree = tree.split() + + kwargs = kwargs.copy() + kwargs.update({'rankdir': rankdir}) + graph = Digraph(graph_attr=kwargs) + + for i, text in enumerate(tree): + if text[0].isdigit(): + node = _parse_node(graph, text) + else: + if i == 0: + # 1st string must be node + raise ValueError('Unable to parse given string as tree') + _parse_edge(graph, node, text, yes_color=yes_color, + no_color=no_color) + + return graph + + +def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs): + """Plot specified tree. + + Parameters + ---------- + booster : Booster + Booster instance + num_trees : int, default 0 + Specify the ordinal number of target tree + rankdir : str, default "UT" + Passed to graphiz via graph_attr + ax : matplotlib Axes, default None + Target axes instance. If None, new figure and axes will be created. + kwargs : + Other keywords passed to to_graphviz + + Returns + ------- + ax : matplotlib Axes + + """ + + try: + import matplotlib.pyplot as plt + import matplotlib.image as image + except ImportError: + raise ImportError('You must install matplotlib to plot tree') + + + if ax is None: + _, ax = plt.subplots(1, 1) + + g = to_graphviz(booster, num_trees=num_trees, rankdir=rankdir, **kwargs) + + s = StringIO() + s.write(g.pipe(format='png')) + s.seek(0) + img = image.imread(s) + + ax.imshow(img) + ax.axis('off') + return ax diff --git a/scripts/travis_osx_install.sh b/scripts/travis_osx_install.sh index 8121afd6b..9f3f6e831 100755 --- a/scripts/travis_osx_install.sh +++ b/scripts/travis_osx_install.sh @@ -7,7 +7,7 @@ fi brew update if [ ${TASK} == "python-package" ]; then - brew install python git + brew install python git graphviz easy_install pip pip install numpy scipy nose fi diff --git a/scripts/travis_script.sh b/scripts/travis_script.sh index 402cb6992..d58eaa5de 100755 --- a/scripts/travis_script.sh +++ b/scripts/travis_script.sh @@ -34,6 +34,8 @@ if [ ${TASK} == "R-package" ]; then fi if [ ${TASK} == "python-package" ]; then + sudo apt-get install graphviz + sudo pip install matplotlib graphviz make all CXX=${CXX} || exit -1 nosetests tests/python || exit -1 fi diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index 77d19595b..2ed1cc462 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -29,3 +29,44 @@ def test_basic(): # assert they are the same assert np.sum(np.abs(preds2-preds)) == 0 +def test_plotting(): + bst2 = xgb.Booster(model_file='xgb.model') + # plotting + + from matplotlib.axes import Axes + from graphviz import Digraph + + ax = xgb.plot_importance(bst2) + assert isinstance(ax, Axes) + assert ax.get_title() == 'Feature importance' + assert ax.get_xlabel() == 'F score' + assert ax.get_ylabel() == 'Features' + assert len(ax.patches) == 4 + + ax = xgb.plot_importance(bst2, color='r', + title='t', xlabel='x', ylabel='y') + assert isinstance(ax, Axes) + assert ax.get_title() == 't' + assert ax.get_xlabel() == 'x' + assert ax.get_ylabel() == 'y' + assert len(ax.patches) == 4 + for p in ax.patches: + assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red + + + ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'], + title=None, xlabel=None, ylabel=None) + assert isinstance(ax, Axes) + assert ax.get_title() == '' + assert ax.get_xlabel() == '' + assert ax.get_ylabel() == '' + assert len(ax.patches) == 4 + assert ax.patches[0].get_facecolor() == (1.0, 0, 0, 1.0) # red + assert ax.patches[1].get_facecolor() == (1.0, 0, 0, 1.0) # red + assert ax.patches[2].get_facecolor() == (0, 0, 1.0, 1.0) # blue + assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue + + g = xgb.to_graphviz(bst2, num_trees=0) + assert isinstance(g, Digraph) + ax = xgb.plot_tree(bst2, num_trees=0) + assert isinstance(ax, Axes)