Merge pull request #555 from sinhrks/plot_sklearn

Allow plot function to handle XGBModel
This commit is contained in:
Tianqi Chen 2015-10-22 08:39:25 -07:00
commit cb7f331ebc
3 changed files with 62 additions and 13 deletions

View File

@ -7,6 +7,7 @@ from __future__ import absolute_import
import re import re
import numpy as np import numpy as np
from .core import Booster from .core import Booster
from .sklearn import XGBModel
from io import BytesIO from io import BytesIO
@ -19,8 +20,8 @@ def plot_importance(booster, ax=None, height=0.2,
Parameters Parameters
---------- ----------
booster : Booster or dict booster : Booster, XGBModel or dict
Booster instance, or dict taken by Booster.get_fscore() Booster or XGBModel instance, or dict taken by Booster.get_fscore()
ax : matplotlib Axes, default None ax : matplotlib Axes, default None
Target axes instance. If None, new figure and axes will be created. Target axes instance. If None, new figure and axes will be created.
height : float, default 0.2 height : float, default 0.2
@ -46,12 +47,14 @@ def plot_importance(booster, ax=None, height=0.2,
except ImportError: except ImportError:
raise ImportError('You must install matplotlib to plot importance') raise ImportError('You must install matplotlib to plot importance')
if isinstance(booster, Booster): if isinstance(booster, XGBModel):
importance = booster.booster().get_fscore()
elif isinstance(booster, Booster):
importance = booster.get_fscore() importance = booster.get_fscore()
elif isinstance(booster, dict): elif isinstance(booster, dict):
importance = booster importance = booster
else: else:
raise ValueError('tree must be Booster or dict instance') raise ValueError('tree must be Booster, XGBModel or dict instance')
if len(importance) == 0: if len(importance) == 0:
raise ValueError('Booster.get_fscore() results in empty') raise ValueError('Booster.get_fscore() results in empty')
@ -142,8 +145,8 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
Parameters Parameters
---------- ----------
booster : Booster booster : Booster, XGBModel
Booster instance Booster or XGBModel instance
num_trees : int, default 0 num_trees : int, default 0
Specify the ordinal number of target tree Specify the ordinal number of target tree
rankdir : str, default "UT" rankdir : str, default "UT"
@ -165,8 +168,11 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
except ImportError: except ImportError:
raise ImportError('You must install graphviz to plot tree') raise ImportError('You must install graphviz to plot tree')
if not isinstance(booster, Booster): if not isinstance(booster, (Booster, XGBModel)):
raise ValueError('booster must be Booster instance') raise ValueError('booster must be Booster or XGBModel instance')
if isinstance(booster, XGBModel):
booster = booster.booster()
tree = booster.get_dump()[num_trees] tree = booster.get_dump()[num_trees]
tree = tree.split() tree = tree.split()
@ -193,8 +199,8 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
Parameters Parameters
---------- ----------
booster : Booster booster : Booster, XGBModel
Booster instance Booster or XGBModel instance
num_trees : int, default 0 num_trees : int, default 0
Specify the ordinal number of target tree Specify the ordinal number of target tree
rankdir : str, default "UT" rankdir : str, default "UT"
@ -216,7 +222,6 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
except ImportError: except ImportError:
raise ImportError('You must install matplotlib to plot tree') raise ImportError('You must install matplotlib to plot tree')
if ax is None: if ax is None:
_, ax = plt.subplots(1, 1) _, ax = plt.subplots(1, 1)

View File

@ -64,7 +64,7 @@ if [ ${TASK} == "python-package" -o ${TASK} == "python-package3" ]; then
conda create -n myenv python=2.7 conda create -n myenv python=2.7
fi fi
source activate myenv source activate myenv
conda install numpy scipy pandas matplotlib nose conda install numpy scipy pandas matplotlib nose scikit-learn
python -m pip install graphviz python -m pip install graphviz
make all CXX=${CXX} || exit -1 make all CXX=${CXX} || exit -1

View File

@ -220,7 +220,6 @@ class TestBasic(unittest.TestCase):
for p in ax.patches: for p in ax.patches:
assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red
ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'], ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
title=None, xlabel=None, ylabel=None) title=None, xlabel=None, ylabel=None)
assert isinstance(ax, Axes) assert isinstance(ax, Axes)
@ -235,5 +234,50 @@ class TestBasic(unittest.TestCase):
g = xgb.to_graphviz(bst2, num_trees=0) g = xgb.to_graphviz(bst2, num_trees=0)
assert isinstance(g, Digraph) assert isinstance(g, Digraph)
ax = xgb.plot_tree(bst2, num_trees=0) ax = xgb.plot_tree(bst2, num_trees=0)
assert isinstance(ax, Axes) assert isinstance(ax, Axes)
def test_sklearn_api(self):
    """Fit an XGBClassifier on an iris split and require under 10% test error."""
    from sklearn import datasets
    try:
        # Preferred location of train_test_split in later scikit-learn releases.
        from sklearn.model_selection import train_test_split
    except ImportError:
        # Fall back to the module this test was originally written against.
        from sklearn.cross_validation import train_test_split

    # Seed numpy's global RNG before the split is drawn.
    np.random.seed(1)

    iris = datasets.load_iris()
    tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)

    classifier = xgb.XGBClassifier()
    classifier.fit(tr_d, tr_l)

    preds = classifier.predict(te_d)
    labels = te_l

    # Divide by a float: with two ints, Python 2 floors the ratio to 0 and
    # the assertion below could never fail.
    n_wrong = sum(1 for p, l in zip(preds, labels) if p != l)
    err = n_wrong / float(len(te_l))
    # error must be smaller than 10%
    assert err < 0.1
def test_sklearn_plotting(self):
    """Check that the plotting helpers accept a fitted XGBClassifier directly."""
    from sklearn import datasets

    dataset = datasets.load_iris()
    model = xgb.XGBClassifier()
    model.fit(dataset.data, dataset.target)

    # Select the non-interactive Agg backend before any figure is created.
    import matplotlib
    matplotlib.use('Agg')

    from matplotlib.axes import Axes
    from graphviz import Digraph

    # Importance plot: default title and axis labels.
    importance_ax = xgb.plot_importance(model)
    assert isinstance(importance_ax, Axes)
    assert importance_ax.get_title() == 'Feature importance'
    assert importance_ax.get_xlabel() == 'F score'
    assert importance_ax.get_ylabel() == 'Features'
    # Presumably one bar per iris feature -- confirm against plot_importance.
    assert len(importance_ax.patches) == 4

    # Graphviz rendering of the first tree.
    graph = xgb.to_graphviz(model, num_trees=0)
    assert isinstance(graph, Digraph)

    # Matplotlib rendering of the first tree.
    tree_ax = xgb.plot_tree(model, num_trees=0)
    assert isinstance(tree_ax, Axes)