ENH: Add visualization to python package

This commit is contained in:
sinhrks
2015-08-11 16:40:09 +09:00
parent a7202ee804
commit d24b36adf9
9 changed files with 311 additions and 2 deletions

View File

@@ -29,3 +29,44 @@ def test_basic():
# assert they are the same
assert np.sum(np.abs(preds2-preds)) == 0
def test_plotting():
bst2 = xgb.Booster(model_file='xgb.model')
# plotting
from matplotlib.axes import Axes
from graphviz import Digraph
ax = xgb.plot_importance(bst2)
assert isinstance(ax, Axes)
assert ax.get_title() == 'Feature importance'
assert ax.get_xlabel() == 'F score'
assert ax.get_ylabel() == 'Features'
assert len(ax.patches) == 4
ax = xgb.plot_importance(bst2, color='r',
title='t', xlabel='x', ylabel='y')
assert isinstance(ax, Axes)
assert ax.get_title() == 't'
assert ax.get_xlabel() == 'x'
assert ax.get_ylabel() == 'y'
assert len(ax.patches) == 4
for p in ax.patches:
assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red
ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
title=None, xlabel=None, ylabel=None)
assert isinstance(ax, Axes)
assert ax.get_title() == ''
assert ax.get_xlabel() == ''
assert ax.get_ylabel() == ''
assert len(ax.patches) == 4
assert ax.patches[0].get_facecolor() == (1.0, 0, 0, 1.0) # red
assert ax.patches[1].get_facecolor() == (1.0, 0, 0, 1.0) # red
assert ax.patches[2].get_facecolor() == (0, 0, 1.0, 1.0) # blue
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
g = xgb.to_graphviz(bst2, num_trees=0)
assert isinstance(g, Digraph)
ax = xgb.plot_tree(bst2, num_trees=0)
assert isinstance(ax, Axes)