Python: adjusts plot_importance ylim

This commit is contained in:
sinhrks
2015-10-24 19:15:43 +09:00
parent 36927632c5
commit 1f19b78287
2 changed files with 26 additions and 5 deletions

View File

@@ -3,6 +3,8 @@ import numpy as np
import xgboost as xgb
import unittest
import matplotlib
matplotlib.use('Agg')
dpath = 'demo/data/'
rng = np.random.RandomState(1994)
@@ -198,9 +200,6 @@ class TestBasic(unittest.TestCase):
bst2 = xgb.Booster(model_file='xgb.model')
# plotting
import matplotlib
matplotlib.use('Agg')
from matplotlib.axes import Axes
from graphviz import Digraph
@@ -239,6 +238,19 @@ class TestBasic(unittest.TestCase):
ax = xgb.plot_tree(bst2, num_trees=0)
assert isinstance(ax, Axes)
def test_importance_plot_lim(self):
np.random.seed(1)
dm = xgb.DMatrix(np.random.randn(100, 100), label=[0, 1]*50)
bst = xgb.train({}, dm)
assert len(bst.get_fscore()) == 71
ax = xgb.plot_importance(bst)
assert ax.get_xlim() == (0., 11.)
assert ax.get_ylim() == (-1., 71.)
ax = xgb.plot_importance(bst, xlim=(0, 5), ylim=(10, 71))
assert ax.get_xlim() == (0., 5.)
assert ax.get_ylim() == (10., 71.)
def test_sklearn_api(self):
from sklearn import datasets
from sklearn.cross_validation import train_test_split