Allow plot function to handle XGBModel
This commit is contained in:
@@ -220,7 +220,6 @@ class TestBasic(unittest.TestCase):
|
||||
for p in ax.patches:
|
||||
assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red
|
||||
|
||||
|
||||
ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
|
||||
title=None, xlabel=None, ylabel=None)
|
||||
assert isinstance(ax, Axes)
|
||||
@@ -235,5 +234,50 @@ class TestBasic(unittest.TestCase):
|
||||
|
||||
g = xgb.to_graphviz(bst2, num_trees=0)
|
||||
assert isinstance(g, Digraph)
|
||||
|
||||
ax = xgb.plot_tree(bst2, num_trees=0)
|
||||
assert isinstance(ax, Axes)
|
||||
|
||||
def test_sklearn_api(self):
|
||||
from sklearn import datasets
|
||||
from sklearn.cross_validation import train_test_split
|
||||
|
||||
np.random.seed(1)
|
||||
|
||||
iris = datasets.load_iris()
|
||||
tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)
|
||||
|
||||
classifier = xgb.XGBClassifier()
|
||||
classifier.fit(tr_d, tr_l)
|
||||
|
||||
preds = classifier.predict(te_d)
|
||||
labels = te_l
|
||||
err = sum([1 for p, l in zip(preds, labels) if p != l]) / len(te_l)
|
||||
# error must be smaller than 10%
|
||||
assert err < 0.1
|
||||
|
||||
def test_sklearn_plotting(self):
|
||||
from sklearn import datasets
|
||||
iris = datasets.load_iris()
|
||||
|
||||
classifier = xgb.XGBClassifier()
|
||||
classifier.fit(iris.data, iris.target)
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from graphviz import Digraph
|
||||
|
||||
ax = xgb.plot_importance(classifier)
|
||||
assert isinstance(ax, Axes)
|
||||
assert ax.get_title() == 'Feature importance'
|
||||
assert ax.get_xlabel() == 'F score'
|
||||
assert ax.get_ylabel() == 'Features'
|
||||
assert len(ax.patches) == 4
|
||||
|
||||
g = xgb.to_graphviz(classifier, num_trees=0)
|
||||
assert isinstance(g, Digraph)
|
||||
|
||||
ax = xgb.plot_tree(classifier, num_trees=0)
|
||||
assert isinstance(ax, Axes)
|
||||
Reference in New Issue
Block a user