# pylint: disable=too-many-locals, too-many-arguments, invalid-name, # pylint: disable=too-many-branches """Plotting Library.""" import json from io import BytesIO from typing import Any, Optional, Union import numpy as np from ._typing import PathLike from .core import Booster from .sklearn import XGBModel Axes = Any # real type is matplotlib.axes.Axes GraphvizSource = Any # real type is graphviz.Source def plot_importance( booster: Union[XGBModel, Booster, dict], ax: Optional[Axes] = None, height: float = 0.2, xlim: Optional[tuple] = None, ylim: Optional[tuple] = None, title: str = "Feature importance", xlabel: str = "F score", ylabel: str = "Features", fmap: PathLike = "", importance_type: str = "weight", max_num_features: Optional[int] = None, grid: bool = True, show_values: bool = True, values_format: str = "{v}", **kwargs: Any, ) -> Axes: """Plot importance based on fitted trees. Parameters ---------- booster : Booster or XGBModel instance, or dict taken by Booster.get_fscore() ax : matplotlib Axes Target axes instance. If None, new figure and axes will be created. grid : Turn the axes grids on or off. Default is True (On). importance_type : How the importance is calculated: either "weight", "gain", or "cover" * "weight" is the number of times a feature appears in a tree * "gain" is the average gain of splits which use the feature * "cover" is the average coverage of splits which use the feature where coverage is defined as the number of samples affected by the split max_num_features : Maximum number of top features displayed on plot. If None, all features will be displayed. height : Bar height, passed to ax.barh() xlim : Tuple passed to axes.xlim() ylim : Tuple passed to axes.ylim() title : Axes title. To disable, pass None. xlabel : X axis title label. To disable, pass None. ylabel : Y axis title label. To disable, pass None. fmap : The name of feature map file. show_values : Show values on plot. To disable, pass False. values_format : Format string for values. "v" will be replaced by the value of the feature importance. e.g. Pass "{v:.2f}" in order to limit the number of digits after the decimal point to two, for each value printed on the graph. kwargs : Other keywords passed to ax.barh() Returns ------- ax : matplotlib Axes """ try: import matplotlib.pyplot as plt except ImportError as e: raise ImportError("You must install matplotlib to plot importance") from e if isinstance(booster, XGBModel): importance = booster.get_booster().get_score( importance_type=importance_type, fmap=fmap ) elif isinstance(booster, Booster): importance = booster.get_score(importance_type=importance_type, fmap=fmap) elif isinstance(booster, dict): importance = booster else: raise ValueError("tree must be Booster, XGBModel or dict instance") if not importance: raise ValueError( "Booster.get_score() results in empty. " + "This maybe caused by having all trees as decision dumps." ) tuples = [(k, importance[k]) for k in importance] if max_num_features is not None: # pylint: disable=invalid-unary-operand-type tuples = sorted(tuples, key=lambda _x: _x[1])[-max_num_features:] else: tuples = sorted(tuples, key=lambda _x: _x[1]) labels, values = zip(*tuples) if ax is None: _, ax = plt.subplots(1, 1) ylocs = np.arange(len(values)) ax.barh(ylocs, values, align="center", height=height, **kwargs) if show_values is True: for x, y in zip(values, ylocs): ax.text(x + 1, y, values_format.format(v=x), va="center") ax.set_yticks(ylocs) ax.set_yticklabels(labels) if xlim is not None: if not isinstance(xlim, tuple) or len(xlim) != 2: raise ValueError("xlim must be a tuple of 2 elements") else: xlim = (0, max(values) * 1.1) ax.set_xlim(xlim) if ylim is not None: if not isinstance(ylim, tuple) or len(ylim) != 2: raise ValueError("ylim must be a tuple of 2 elements") else: ylim = (-1, len(values)) ax.set_ylim(ylim) if title is not None: ax.set_title(title) if xlabel is not None: ax.set_xlabel(xlabel) if ylabel is not None: ax.set_ylabel(ylabel) ax.grid(grid) return ax def to_graphviz( booster: Union[Booster, XGBModel], fmap: PathLike = "", num_trees: int = 0, rankdir: Optional[str] = None, yes_color: Optional[str] = None, no_color: Optional[str] = None, condition_node_params: Optional[dict] = None, leaf_node_params: Optional[dict] = None, **kwargs: Any, ) -> GraphvizSource: """Convert specified tree to graphviz instance. IPython can automatically plot the returned graphviz instance. Otherwise, you should call .render() method of the returned graphviz instance. Parameters ---------- booster : Booster or XGBModel instance fmap : The name of feature map file num_trees : Specify the ordinal number of target tree rankdir : Passed to graphviz via graph_attr yes_color : Edge color when meets the node condition. no_color : Edge color when doesn't meet the node condition. condition_node_params : Condition node configuration for for graphviz. Example: .. code-block:: python {'shape': 'box', 'style': 'filled,rounded', 'fillcolor': '#78bceb'} leaf_node_params : Leaf node configuration for graphviz. Example: .. code-block:: python {'shape': 'box', 'style': 'filled', 'fillcolor': '#e48038'} kwargs : Other keywords passed to graphviz graph_attr, e.g. ``graph [ {key} = {value} ]`` Returns ------- graph: graphviz.Source """ try: from graphviz import Source except ImportError as e: raise ImportError("You must install graphviz to plot tree") from e if isinstance(booster, XGBModel): booster = booster.get_booster() # squash everything back into kwargs again for compatibility parameters = "dot" extra = {} for key, value in kwargs.items(): extra[key] = value if rankdir is not None: kwargs["graph_attrs"] = {} kwargs["graph_attrs"]["rankdir"] = rankdir for key, value in extra.items(): if kwargs.get("graph_attrs", None) is not None: kwargs["graph_attrs"][key] = value else: kwargs["graph_attrs"] = {} del kwargs[key] if yes_color is not None or no_color is not None: kwargs["edge"] = {} if yes_color is not None: kwargs["edge"]["yes_color"] = yes_color if no_color is not None: kwargs["edge"]["no_color"] = no_color if condition_node_params is not None: kwargs["condition_node_params"] = condition_node_params if leaf_node_params is not None: kwargs["leaf_node_params"] = leaf_node_params if kwargs: parameters += ":" parameters += json.dumps(kwargs) tree = booster.get_dump(fmap=fmap, dump_format=parameters)[num_trees] g = Source(tree) return g def plot_tree( booster: Booster, fmap: PathLike = "", num_trees: int = 0, rankdir: Optional[str] = None, ax: Optional[Axes] = None, **kwargs: Any, ) -> Axes: """Plot specified tree. Parameters ---------- booster : Booster, XGBModel Booster or XGBModel instance fmap: str (optional) The name of feature map file num_trees : int, default 0 Specify the ordinal number of target tree rankdir : str, default "TB" Passed to graphviz via graph_attr ax : matplotlib Axes, default None Target axes instance. If None, new figure and axes will be created. kwargs : Other keywords passed to to_graphviz Returns ------- ax : matplotlib Axes """ try: from matplotlib import image from matplotlib import pyplot as plt except ImportError as e: raise ImportError("You must install matplotlib to plot tree") from e if ax is None: _, ax = plt.subplots(1, 1) g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs) s = BytesIO() s.write(g.pipe(format="png")) s.seek(0) img = image.imread(s) ax.imshow(img) ax.axis("off") return ax