Allow plot function to handle XGBModel
This commit is contained in:
parent
eee3046624
commit
6f046327ac
@ -7,6 +7,7 @@ from __future__ import absolute_import
|
|||||||
import re
|
import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster
|
from .core import Booster
|
||||||
|
from .sklearn import XGBModel
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
@ -19,8 +20,8 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
booster : Booster or dict
|
booster : Booster, XGBModel or dict
|
||||||
Booster instance, or dict taken by Booster.get_fscore()
|
Booster or XGBModel instance, or dict taken by Booster.get_fscore()
|
||||||
ax : matplotlib Axes, default None
|
ax : matplotlib Axes, default None
|
||||||
Target axes instance. If None, new figure and axes will be created.
|
Target axes instance. If None, new figure and axes will be created.
|
||||||
height : float, default 0.2
|
height : float, default 0.2
|
||||||
@ -46,12 +47,14 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError('You must install matplotlib to plot importance')
|
raise ImportError('You must install matplotlib to plot importance')
|
||||||
|
|
||||||
if isinstance(booster, Booster):
|
if isinstance(booster, XGBModel):
|
||||||
|
importance = booster.booster().get_fscore()
|
||||||
|
elif isinstance(booster, Booster):
|
||||||
importance = booster.get_fscore()
|
importance = booster.get_fscore()
|
||||||
elif isinstance(booster, dict):
|
elif isinstance(booster, dict):
|
||||||
importance = booster
|
importance = booster
|
||||||
else:
|
else:
|
||||||
raise ValueError('tree must be Booster or dict instance')
|
raise ValueError('tree must be Booster, XGBModel or dict instance')
|
||||||
|
|
||||||
if len(importance) == 0:
|
if len(importance) == 0:
|
||||||
raise ValueError('Booster.get_fscore() results in empty')
|
raise ValueError('Booster.get_fscore() results in empty')
|
||||||
@ -142,8 +145,8 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
booster : Booster
|
booster : Booster, XGBModel
|
||||||
Booster instance
|
Booster or XGBModel instance
|
||||||
num_trees : int, default 0
|
num_trees : int, default 0
|
||||||
Specify the ordinal number of target tree
|
Specify the ordinal number of target tree
|
||||||
rankdir : str, default "UT"
|
rankdir : str, default "UT"
|
||||||
@ -165,8 +168,11 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError('You must install graphviz to plot tree')
|
raise ImportError('You must install graphviz to plot tree')
|
||||||
|
|
||||||
if not isinstance(booster, Booster):
|
if not isinstance(booster, (Booster, XGBModel)):
|
||||||
raise ValueError('booster must be Booster instance')
|
raise ValueError('booster must be Booster or XGBModel instance')
|
||||||
|
|
||||||
|
if isinstance(booster, XGBModel):
|
||||||
|
booster = booster.booster()
|
||||||
|
|
||||||
tree = booster.get_dump()[num_trees]
|
tree = booster.get_dump()[num_trees]
|
||||||
tree = tree.split()
|
tree = tree.split()
|
||||||
@ -193,8 +199,8 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
booster : Booster
|
booster : Booster, XGBModel
|
||||||
Booster instance
|
Booster or XGBModel instance
|
||||||
num_trees : int, default 0
|
num_trees : int, default 0
|
||||||
Specify the ordinal number of target tree
|
Specify the ordinal number of target tree
|
||||||
rankdir : str, default "UT"
|
rankdir : str, default "UT"
|
||||||
@ -216,7 +222,6 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError('You must install matplotlib to plot tree')
|
raise ImportError('You must install matplotlib to plot tree')
|
||||||
|
|
||||||
|
|
||||||
if ax is None:
|
if ax is None:
|
||||||
_, ax = plt.subplots(1, 1)
|
_, ax = plt.subplots(1, 1)
|
||||||
|
|
||||||
|
|||||||
@ -64,7 +64,7 @@ if [ ${TASK} == "python-package" -o ${TASK} == "python-package3" ]; then
|
|||||||
conda create -n myenv python=2.7
|
conda create -n myenv python=2.7
|
||||||
fi
|
fi
|
||||||
source activate myenv
|
source activate myenv
|
||||||
conda install numpy scipy pandas matplotlib nose
|
conda install numpy scipy pandas matplotlib nose scikit-learn
|
||||||
python -m pip install graphviz
|
python -m pip install graphviz
|
||||||
|
|
||||||
make all CXX=${CXX} || exit -1
|
make all CXX=${CXX} || exit -1
|
||||||
|
|||||||
@ -220,7 +220,6 @@ class TestBasic(unittest.TestCase):
|
|||||||
for p in ax.patches:
|
for p in ax.patches:
|
||||||
assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red
|
assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red
|
||||||
|
|
||||||
|
|
||||||
ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
|
ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
|
||||||
title=None, xlabel=None, ylabel=None)
|
title=None, xlabel=None, ylabel=None)
|
||||||
assert isinstance(ax, Axes)
|
assert isinstance(ax, Axes)
|
||||||
@ -235,5 +234,50 @@ class TestBasic(unittest.TestCase):
|
|||||||
|
|
||||||
g = xgb.to_graphviz(bst2, num_trees=0)
|
g = xgb.to_graphviz(bst2, num_trees=0)
|
||||||
assert isinstance(g, Digraph)
|
assert isinstance(g, Digraph)
|
||||||
|
|
||||||
ax = xgb.plot_tree(bst2, num_trees=0)
|
ax = xgb.plot_tree(bst2, num_trees=0)
|
||||||
assert isinstance(ax, Axes)
|
assert isinstance(ax, Axes)
|
||||||
|
|
||||||
|
def test_sklearn_api(self):
|
||||||
|
from sklearn import datasets
|
||||||
|
from sklearn.cross_validation import train_test_split
|
||||||
|
|
||||||
|
np.random.seed(1)
|
||||||
|
|
||||||
|
iris = datasets.load_iris()
|
||||||
|
tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)
|
||||||
|
|
||||||
|
classifier = xgb.XGBClassifier()
|
||||||
|
classifier.fit(tr_d, tr_l)
|
||||||
|
|
||||||
|
preds = classifier.predict(te_d)
|
||||||
|
labels = te_l
|
||||||
|
err = sum([1 for p, l in zip(preds, labels) if p != l]) / len(te_l)
|
||||||
|
# error must be smaller than 10%
|
||||||
|
assert err < 0.1
|
||||||
|
|
||||||
|
def test_sklearn_plotting(self):
|
||||||
|
from sklearn import datasets
|
||||||
|
iris = datasets.load_iris()
|
||||||
|
|
||||||
|
classifier = xgb.XGBClassifier()
|
||||||
|
classifier.fit(iris.data, iris.target)
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
|
||||||
|
from matplotlib.axes import Axes
|
||||||
|
from graphviz import Digraph
|
||||||
|
|
||||||
|
ax = xgb.plot_importance(classifier)
|
||||||
|
assert isinstance(ax, Axes)
|
||||||
|
assert ax.get_title() == 'Feature importance'
|
||||||
|
assert ax.get_xlabel() == 'F score'
|
||||||
|
assert ax.get_ylabel() == 'Features'
|
||||||
|
assert len(ax.patches) == 4
|
||||||
|
|
||||||
|
g = xgb.to_graphviz(classifier, num_trees=0)
|
||||||
|
assert isinstance(g, Digraph)
|
||||||
|
|
||||||
|
ax = xgb.plot_tree(classifier, num_trees=0)
|
||||||
|
assert isinstance(ax, Axes)
|
||||||
Loading…
x
Reference in New Issue
Block a user