Allow plot function to handle XGBModel
This commit is contained in:
@@ -7,6 +7,7 @@ from __future__ import absolute_import
|
||||
import re
|
||||
import numpy as np
|
||||
from .core import Booster
|
||||
from .sklearn import XGBModel
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
@@ -19,8 +20,8 @@ def plot_importance(booster, ax=None, height=0.2,
|
||||
|
||||
Parameters
|
||||
----------
|
||||
booster : Booster or dict
|
||||
Booster instance, or dict taken by Booster.get_fscore()
|
||||
booster : Booster, XGBModel or dict
|
||||
Booster or XGBModel instance, or dict taken by Booster.get_fscore()
|
||||
ax : matplotlib Axes, default None
|
||||
Target axes instance. If None, new figure and axes will be created.
|
||||
height : float, default 0.2
|
||||
@@ -46,12 +47,14 @@ def plot_importance(booster, ax=None, height=0.2,
|
||||
except ImportError:
|
||||
raise ImportError('You must install matplotlib to plot importance')
|
||||
|
||||
if isinstance(booster, Booster):
|
||||
if isinstance(booster, XGBModel):
|
||||
importance = booster.booster().get_fscore()
|
||||
elif isinstance(booster, Booster):
|
||||
importance = booster.get_fscore()
|
||||
elif isinstance(booster, dict):
|
||||
importance = booster
|
||||
else:
|
||||
raise ValueError('tree must be Booster or dict instance')
|
||||
raise ValueError('tree must be Booster, XGBModel or dict instance')
|
||||
|
||||
if len(importance) == 0:
|
||||
raise ValueError('Booster.get_fscore() results in empty')
|
||||
@@ -142,8 +145,8 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
|
||||
|
||||
Parameters
|
||||
----------
|
||||
booster : Booster
|
||||
Booster instance
|
||||
booster : Booster, XGBModel
|
||||
Booster or XGBModel instance
|
||||
num_trees : int, default 0
|
||||
Specify the ordinal number of target tree
|
||||
rankdir : str, default "UT"
|
||||
@@ -165,8 +168,11 @@ def to_graphviz(booster, num_trees=0, rankdir='UT',
|
||||
except ImportError:
|
||||
raise ImportError('You must install graphviz to plot tree')
|
||||
|
||||
if not isinstance(booster, Booster):
|
||||
raise ValueError('booster must be Booster instance')
|
||||
if not isinstance(booster, (Booster, XGBModel)):
|
||||
raise ValueError('booster must be Booster or XGBModel instance')
|
||||
|
||||
if isinstance(booster, XGBModel):
|
||||
booster = booster.booster()
|
||||
|
||||
tree = booster.get_dump()[num_trees]
|
||||
tree = tree.split()
|
||||
@@ -193,8 +199,8 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
booster : Booster
|
||||
Booster instance
|
||||
booster : Booster, XGBModel
|
||||
Booster or XGBModel instance
|
||||
num_trees : int, default 0
|
||||
Specify the ordinal number of target tree
|
||||
rankdir : str, default "UT"
|
||||
@@ -216,7 +222,6 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs):
|
||||
except ImportError:
|
||||
raise ImportError('You must install matplotlib to plot tree')
|
||||
|
||||
|
||||
if ax is None:
|
||||
_, ax = plt.subplots(1, 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user