diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py index 50a844a1e..97c4cc2f5 100644 --- a/python-package/xgboost/plotting.py +++ b/python-package/xgboost/plotting.py @@ -7,6 +7,7 @@ from __future__ import absolute_import import re import numpy as np from .core import Booster +from .sklearn import XGBModel from io import BytesIO @@ -19,8 +20,8 @@ def plot_importance(booster, ax=None, height=0.2, Parameters ---------- - booster : Booster or dict - Booster instance, or dict taken by Booster.get_fscore() + booster : Booster, XGBModel or dict + Booster or XGBModel instance, or dict taken by Booster.get_fscore() ax : matplotlib Axes, default None Target axes instance. If None, new figure and axes will be created. height : float, default 0.2 @@ -46,12 +47,14 @@ def plot_importance(booster, ax=None, height=0.2, except ImportError: raise ImportError('You must install matplotlib to plot importance') - if isinstance(booster, Booster): + if isinstance(booster, XGBModel): + importance = booster.booster().get_fscore() + elif isinstance(booster, Booster): importance = booster.get_fscore() elif isinstance(booster, dict): importance = booster else: - raise ValueError('tree must be Booster or dict instance') + raise ValueError('tree must be Booster, XGBModel or dict instance') if len(importance) == 0: raise ValueError('Booster.get_fscore() results in empty') @@ -142,8 +145,8 @@ def to_graphviz(booster, num_trees=0, rankdir='UT', Parameters ---------- - booster : Booster - Booster instance + booster : Booster, XGBModel + Booster or XGBModel instance num_trees : int, default 0 Specify the ordinal number of target tree rankdir : str, default "UT" @@ -165,8 +168,11 @@ def to_graphviz(booster, num_trees=0, rankdir='UT', except ImportError: raise ImportError('You must install graphviz to plot tree') - if not isinstance(booster, Booster): - raise ValueError('booster must be Booster instance') + if not isinstance(booster, (Booster, XGBModel)): + raise ValueError('booster must be Booster or XGBModel instance') + + if isinstance(booster, XGBModel): + booster = booster.booster() tree = booster.get_dump()[num_trees] tree = tree.split() @@ -193,8 +199,8 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs): Parameters ---------- - booster : Booster - Booster instance + booster : Booster, XGBModel + Booster or XGBModel instance num_trees : int, default 0 Specify the ordinal number of target tree rankdir : str, default "UT" @@ -216,7 +222,6 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs): except ImportError: raise ImportError('You must install matplotlib to plot tree') - if ax is None: _, ax = plt.subplots(1, 1) diff --git a/scripts/travis_script.sh b/scripts/travis_script.sh index 3a026966d..1e62b5b46 100755 --- a/scripts/travis_script.sh +++ b/scripts/travis_script.sh @@ -64,7 +64,7 @@ if [ ${TASK} == "python-package" -o ${TASK} == "python-package3" ]; then conda create -n myenv python=2.7 fi source activate myenv - conda install numpy scipy pandas matplotlib nose + conda install numpy scipy pandas matplotlib nose scikit-learn python -m pip install graphviz make all CXX=${CXX} || exit -1 diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index fa287b247..710af8e4c 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -220,7 +220,6 @@ class TestBasic(unittest.TestCase): for p in ax.patches: assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red - ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'], title=None, xlabel=None, ylabel=None) assert isinstance(ax, Axes) @@ -235,5 +234,50 @@ class TestBasic(unittest.TestCase): g = xgb.to_graphviz(bst2, num_trees=0) assert isinstance(g, Digraph) + ax = xgb.plot_tree(bst2, num_trees=0) assert isinstance(ax, Axes) + + def test_sklearn_api(self): + from sklearn import datasets + from sklearn.cross_validation import train_test_split + + np.random.seed(1) + + iris = datasets.load_iris() + tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120) + + classifier = xgb.XGBClassifier() + classifier.fit(tr_d, tr_l) + + preds = classifier.predict(te_d) + labels = te_l + err = sum([1 for p, l in zip(preds, labels) if p != l]) / len(te_l) + # error must be smaller than 10% + assert err < 0.1 + + def test_sklearn_plotting(self): + from sklearn import datasets + iris = datasets.load_iris() + + classifier = xgb.XGBClassifier() + classifier.fit(iris.data, iris.target) + + import matplotlib + matplotlib.use('Agg') + + from matplotlib.axes import Axes + from graphviz import Digraph + + ax = xgb.plot_importance(classifier) + assert isinstance(ax, Axes) + assert ax.get_title() == 'Feature importance' + assert ax.get_xlabel() == 'F score' + assert ax.get_ylabel() == 'Features' + assert len(ax.patches) == 4 + + g = xgb.to_graphviz(classifier, num_trees=0) + assert isinstance(g, Digraph) + + ax = xgb.plot_tree(classifier, num_trees=0) + assert isinstance(ax, Axes) \ No newline at end of file