diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py index a5aca60e8..cf81ef4f0 100644 --- a/python-package/xgboost/plotting.py +++ b/python-package/xgboost/plotting.py @@ -153,7 +153,7 @@ def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'): raise ValueError('Unable to parse edge: {0}'.format(text)) -def to_graphviz(booster, num_trees=0, rankdir='UT', +def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT', yes_color='#0000FF', no_color='#FF0000', **kwargs): """Convert specified tree to graphviz instance. IPython can automatically plot the @@ -164,6 +164,8 @@ def to_graphviz(booster, num_trees=0, rankdir='UT', ---------- booster : Booster, XGBModel Booster or XGBModel instance + fmap: str (optional) + The name of feature map file num_trees : int, default 0 Specify the ordinal number of target tree rankdir : str, default "UT" @@ -191,7 +193,7 @@ def to_graphviz(booster, num_trees=0, rankdir='UT', if isinstance(booster, XGBModel): booster = booster.booster() - tree = booster.get_dump()[num_trees] + tree = booster.get_dump(fmap=fmap)[num_trees] tree = tree.split() kwargs = kwargs.copy() @@ -211,13 +213,15 @@ def to_graphviz(booster, num_trees=0, rankdir='UT', return graph -def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs): +def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs): """Plot specified tree. Parameters ---------- booster : Booster, XGBModel Booster or XGBModel instance + fmap: str (optional) + The name of feature map file num_trees : int, default 0 Specify the ordinal number of target tree rankdir : str, default "UT" @@ -242,7 +246,7 @@ def plot_tree(booster, num_trees=0, rankdir='UT', ax=None, **kwargs): if ax is None: _, ax = plt.subplots(1, 1) - g = to_graphviz(booster, num_trees=num_trees, rankdir=rankdir, **kwargs) + g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs) s = BytesIO() s.write(g.pipe(format='png'))