From 1f19b7828794595684eb9aeb09bcfe5bac167c99 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Sat, 24 Oct 2015 19:15:43 +0900 Subject: [PATCH] Python: adjusts plot_importance ylim --- python-package/xgboost/plotting.py | 13 +++++++++++-- tests/python/test_basic.py | 18 +++++++++++++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py index 97c4cc2f5..f8489a6f8 100644 --- a/python-package/xgboost/plotting.py +++ b/python-package/xgboost/plotting.py @@ -12,7 +12,7 @@ from .sklearn import XGBModel from io import BytesIO def plot_importance(booster, ax=None, height=0.2, - xlim=None, title='Feature importance', + xlim=None, ylim=None, title='Feature importance', xlabel='F score', ylabel='Features', grid=True, **kwargs): @@ -28,6 +28,8 @@ def plot_importance(booster, ax=None, height=0.2, Bar height, passed to ax.barh() xlim : tuple, default None Tuple passed to axes.xlim() + ylim : tuple, default None + Tuple passed to axes.ylim() title : str, default "Feature importance" Axes title. To disable, pass None. xlabel : str, default "F score" @@ -76,12 +78,19 @@ def plot_importance(booster, ax=None, height=0.2, ax.set_yticklabels(labels) if xlim is not None: - if not isinstance(xlim, tuple) or len(xlim, 2): + if not isinstance(xlim, tuple) or len(xlim) != 2: raise ValueError('xlim must be a tuple of 2 elements') else: xlim = (0, max(values) * 1.1) ax.set_xlim(xlim) + if ylim is not None: + if not isinstance(ylim, tuple) or len(ylim) != 2: + raise ValueError('ylim must be a tuple of 2 elements') + else: + ylim = (-1, len(importance)) + ax.set_ylim(ylim) + if title is not None: ax.set_title(title) if xlabel is not None: diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index 79288b371..a8e0d5238 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -3,6 +3,8 @@ import numpy as np import xgboost as xgb import unittest +import matplotlib +matplotlib.use('Agg') dpath = 'demo/data/' rng = np.random.RandomState(1994) @@ -198,9 +200,6 @@ class TestBasic(unittest.TestCase): bst2 = xgb.Booster(model_file='xgb.model') # plotting - import matplotlib - matplotlib.use('Agg') - from matplotlib.axes import Axes from graphviz import Digraph @@ -239,6 +238,19 @@ class TestBasic(unittest.TestCase): ax = xgb.plot_tree(bst2, num_trees=0) assert isinstance(ax, Axes) + def test_importance_plot_lim(self): + np.random.seed(1) + dm = xgb.DMatrix(np.random.randn(100, 100), label=[0, 1]*50) + bst = xgb.train({}, dm) + assert len(bst.get_fscore()) == 71 + ax = xgb.plot_importance(bst) + assert ax.get_xlim() == (0., 11.) + assert ax.get_ylim() == (-1., 71.) + + ax = xgb.plot_importance(bst, xlim=(0, 5), ylim=(10, 71)) + assert ax.get_xlim() == (0., 5.) + assert ax.get_ylim() == (10., 71.) + def test_sklearn_api(self): from sklearn import datasets from sklearn.cross_validation import train_test_split