Implement tree model dump with code generator. (#4602)

* Implement tree model dump with a code generator.

* Split up generators.
* Implement graphviz generator.
* Use pattern matching.

* [Breaking] Return a Source in `to_graphviz` instead of Digraph in Python package.


Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2019-06-26 15:20:44 +08:00 committed by GitHub
parent fe2de6f415
commit 8bdf15120a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 802 additions and 264 deletions

View File

@ -7,6 +7,8 @@
#ifndef XGBOOST_FEATURE_MAP_H_
#define XGBOOST_FEATURE_MAP_H_
#include <xgboost/logging.h>
#include <vector>
#include <string>
#include <cstring>

View File

@ -1419,7 +1419,7 @@ class Booster(object):
with_stats : bool, optional
Controls whether the split statistics are output.
dump_format : string, optional
Format of model dump. Can be 'text' or 'json'.
Format of model dump. Can be 'text', 'json' or 'dot'.
"""
length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)()

View File

@ -1,10 +1,7 @@
# coding: utf-8
# pylint: disable=too-many-locals, too-many-arguments, invalid-name,
# pylint: disable=too-many-branches
# coding: utf-8
"""Plotting Library."""
from __future__ import absolute_import
import re
from io import BytesIO
import numpy as np
from .core import Booster
@ -61,7 +58,8 @@ def plot_importance(booster, ax=None, height=0.2,
raise ImportError('You must install matplotlib to plot importance')
if isinstance(booster, XGBModel):
importance = booster.get_booster().get_score(importance_type=importance_type)
importance = booster.get_booster().get_score(
importance_type=importance_type)
elif isinstance(booster, Booster):
importance = booster.get_score(importance_type=importance_type)
elif isinstance(booster, dict):
@ -117,56 +115,11 @@ def plot_importance(booster, ax=None, height=0.2,
return ax
_NODEPAT = re.compile(r'(\d+):\[(.+)\]')
_LEAFPAT = re.compile(r'(\d+):(leaf=.+)')
_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')
def _parse_node(graph, text, condition_node_params, leaf_node_params):
"""parse dumped node"""
match = _NODEPAT.match(text)
if match is not None:
node = match.group(1)
graph.node(node, label=match.group(2), **condition_node_params)
return node
match = _LEAFPAT.match(text)
if match is not None:
node = match.group(1)
graph.node(node, label=match.group(2), **leaf_node_params)
return node
raise ValueError('Unable to parse node: {0}'.format(text))
def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
"""parse dumped edge"""
try:
match = _EDGEPAT.match(text)
if match is not None:
yes, no, missing = match.groups()
if yes == missing:
graph.edge(node, yes, label='yes, missing', color=yes_color)
graph.edge(node, no, label='no', color=no_color)
else:
graph.edge(node, yes, label='yes', color=yes_color)
graph.edge(node, no, label='no, missing', color=no_color)
return
except ValueError:
pass
match = _EDGEPAT2.match(text)
if match is not None:
yes, no = match.groups()
graph.edge(node, yes, label='yes', color=yes_color)
graph.edge(node, no, label='no', color=no_color)
return
raise ValueError('Unable to parse edge: {0}'.format(text))
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
yes_color='#0000FF', no_color='#FF0000',
def to_graphviz(booster, fmap='', num_trees=0, rankdir=None,
yes_color=None, no_color=None,
condition_node_params=None, leaf_node_params=None, **kwargs):
"""Convert specified tree to graphviz instance. IPython can automatically plot the
returned graphiz instance. Otherwise, you should call .render() method
"""Convert specified tree to graphviz instance. IPython can automatically plot
the returned graphviz instance. Otherwise, you should call .render() method
of the returned graphviz instance.
Parameters
@ -184,64 +137,77 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
no_color : str, default '#FF0000'
Edge color when doesn't meet the node condition.
condition_node_params : dict (optional)
condition node configuration,
{'shape':'box',
'style':'filled,rounded',
'fillcolor':'#78bceb'}
Condition node configuration for graphviz. Example:
.. code-block:: python
{'shape': 'box',
'style': 'filled,rounded',
'fillcolor': '#78bceb'}
leaf_node_params : dict (optional)
leaf node configuration
{'shape':'box',
'style':'filled',
'fillcolor':'#e48038'}
Leaf node configuration for graphviz. Example:
kwargs :
Other keywords passed to graphviz graph_attr
.. code-block:: python
{'shape': 'box',
'style': 'filled',
'fillcolor': '#e48038'}
kwargs : Other keywords passed to graphviz graph_attr, E.g.:
``graph [ {key} = {value} ]``
Returns
-------
ax : matplotlib Axes
graph: graphviz.Source
"""
if condition_node_params is None:
condition_node_params = {}
if leaf_node_params is None:
leaf_node_params = {}
try:
from graphviz import Digraph
from graphviz import Source
except ImportError:
raise ImportError('You must install graphviz to plot tree')
if not isinstance(booster, (Booster, XGBModel)):
raise ValueError('booster must be Booster or XGBModel instance')
if isinstance(booster, XGBModel):
booster = booster.get_booster()
tree = booster.get_dump(fmap=fmap)[num_trees]
tree = tree.split()
# squash everything back into kwargs again for compatibility
parameters = 'dot'
extra = {}
for key, value in kwargs.items():
extra[key] = value
kwargs = kwargs.copy()
kwargs.update({'rankdir': rankdir})
graph = Digraph(graph_attr=kwargs)
for i, text in enumerate(tree):
if text[0].isdigit():
node = _parse_node(
graph, text, condition_node_params=condition_node_params,
leaf_node_params=leaf_node_params)
if rankdir is not None:
kwargs['graph_attrs'] = {}
kwargs['graph_attrs']['rankdir'] = rankdir
for key, value in extra.items():
if 'graph_attrs' in kwargs.keys():
kwargs['graph_attrs'][key] = value
else:
if i == 0:
# 1st string must be node
raise ValueError('Unable to parse given string as tree')
_parse_edge(graph, node, text, yes_color=yes_color,
no_color=no_color)
kwargs['graph_attrs'] = {}
del kwargs[key]
return graph
if yes_color is not None or no_color is not None:
kwargs['edge'] = {}
if yes_color is not None:
kwargs['edge']['yes_color'] = yes_color
if no_color is not None:
kwargs['edge']['no_color'] = no_color
if condition_node_params is not None:
kwargs['condition_node_params'] = condition_node_params
if leaf_node_params is not None:
kwargs['leaf_node_params'] = leaf_node_params
if kwargs:
parameters += ':'
parameters += str(kwargs)
tree = booster.get_dump(
fmap=fmap,
dump_format=parameters)[num_trees]
g = Source(tree)
return g
def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
def plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs):
"""Plot specified tree.
Parameters
@ -252,7 +218,7 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
The name of feature map file
num_trees : int, default 0
Specify the ordinal number of target tree
rankdir : str, default "UT"
rankdir : str, default "TB"
Passed to graphviz via graph_attr
ax : matplotlib Axes, default None
Target axes instance. If None, new figure and axes will be created.
@ -264,18 +230,17 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
ax : matplotlib Axes
"""
try:
import matplotlib.pyplot as plt
import matplotlib.image as image
from matplotlib import pyplot as plt
from matplotlib import image
except ImportError:
raise ImportError('You must install matplotlib to plot tree')
if ax is None:
_, ax = plt.subplots(1, 1)
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees,
rankdir=rankdir, **kwargs)
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir,
**kwargs)
s = BytesIO()
s.write(g.pipe(format='png'))

View File

@ -1033,19 +1033,21 @@ inline void XGBoostDumpModelImpl(
*out_models = dmlc::BeginPtr(charp_vecs);
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
}
XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
const char* fmap,
int with_stats,
xgboost::bst_ulong* len,
const char*** out_models) {
const char* fmap,
int with_stats,
xgboost::bst_ulong* len,
const char*** out_models) {
return XGBoosterDumpModelEx(handle, fmap, with_stats, "text", len, out_models);
}
XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
const char* fmap,
int with_stats,
const char *format,
xgboost::bst_ulong* len,
const char*** out_models) {
const char* fmap,
int with_stats,
const char *format,
xgboost::bst_ulong* len,
const char*** out_models) {
API_BEGIN();
CHECK_HANDLE();
FeatureMap featmap;
@ -1060,23 +1062,24 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
}
XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
int fnum,
const char** fname,
const char** ftype,
int with_stats,
xgboost::bst_ulong* len,
const char*** out_models) {
int fnum,
const char** fname,
const char** ftype,
int with_stats,
xgboost::bst_ulong* len,
const char*** out_models) {
return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats,
"text", len, out_models);
"text", len, out_models);
}
XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
int fnum,
const char** fname,
const char** ftype,
int with_stats,
const char *format,
xgboost::bst_ulong* len,
const char*** out_models) {
int fnum,
const char** fname,
const char** ftype,
int with_stats,
const char *format,
xgboost::bst_ulong* len,
const char*** out_models) {
API_BEGIN();
CHECK_HANDLE();
FeatureMap featmap;

View File

@ -10,7 +10,7 @@
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
#include <nvToolsExt.h>
#endif
#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
namespace xgboost {
namespace common {
@ -98,7 +98,7 @@ struct Monitor {
stats.timer.Start();
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
stats.nvtx_id = nvtxRangeStartA(name.c_str());
#endif
#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
}
}
void StopCuda(const std::string &name) {
@ -108,7 +108,7 @@ struct Monitor {
stats.count++;
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
nvtxRangeEnd(stats.nvtx_id);
#endif
#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
}
}
};

View File

@ -1,14 +1,19 @@
/*!
* Copyright 2015 by Contributors
* Copyright 2015-2019 by Contributors
* \file tree_model.cc
* \brief model structure for tree
*/
#include <dmlc/registry.h>
#include <dmlc/json.h>
#include <xgboost/tree_model.h>
#include <xgboost/logging.h>
#include <sstream>
#include <limits>
#include <cmath>
#include <iomanip>
#include "./param.h"
#include "param.h"
namespace xgboost {
// register tree parameter
@ -17,158 +22,602 @@ DMLC_REGISTER_PARAMETER(TreeParam);
namespace tree {
DMLC_REGISTER_PARAMETER(TrainParam);
}
// internal function to dump regression tree to text
void DumpRegTree(std::stringstream& fo, // NOLINT(*)
const RegTree& tree,
const FeatureMap& fmap,
int nid, int depth, int add_comma,
bool with_stats, std::string format) {
int float_max_precision = std::numeric_limits<bst_float>::max_digits10;
if (format == "json") {
if (add_comma) {
fo << ",";
}
if (depth != 0) {
fo << std::endl;
}
for (int i = 0; i < depth + 1; ++i) {
fo << " ";
}
} else {
for (int i = 0; i < depth; ++i) {
fo << '\t';
}
/*!
* \brief Base class for dump model implementation, modeling closely after code generator.
*/
class TreeGenerator {
protected:
static int32_t constexpr kFloatMaxPrecision =
std::numeric_limits<bst_float>::max_digits10;
FeatureMap const& fmap_;
std::stringstream ss_;
bool const with_stats_;
template <typename Float>
static std::string ToStr(Float value) {
static_assert(std::is_floating_point<Float>::value,
"Use std::to_string instead for non-floating point values.");
std::stringstream ss;
ss << std::setprecision(kFloatMaxPrecision) << value;
return ss.str();
}
if (tree[nid].IsLeaf()) {
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"leaf\": " << std::setprecision(float_max_precision) << tree[nid].LeafValue();
if (with_stats) {
fo << ", \"cover\": " << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
}
fo << " }";
} else {
fo << nid << ":leaf=" << std::setprecision(float_max_precision) << tree[nid].LeafValue();
if (with_stats) {
fo << ",cover=" << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
}
fo << '\n';
static std::string Tabs(uint32_t n) {
std::string res;
for (uint32_t i = 0; i < n; ++i) {
res += '\t';
}
} else {
// right then left,
bst_float cond = tree[nid].SplitCond();
const unsigned split_index = tree[nid].SplitIndex();
if (split_index < fmap.Size()) {
switch (fmap.type(split_index)) {
return res;
}
/* \brief Find the first occurrence of each key in input and replace it with the
 * corresponding value.
 */
static std::string Match(std::string const& input,
std::map<std::string, std::string> const& replacements) {
std::string result = input;
for (auto const& kv : replacements) {
auto pos = result.find(kv.first);
CHECK_NE(pos, std::string::npos);
result.replace(pos, kv.first.length(), kv.second);
}
return result;
}
virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) {
return "";
}
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) {
return "";
}
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) {
return "";
}
virtual std::string NodeStat(RegTree const& tree, int32_t nid) {
return "";
}
virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
auto const split_index = tree[nid].SplitIndex();
std::string result;
if (split_index < fmap_.Size()) {
switch (fmap_.type(split_index)) {
case FeatureMap::kIndicator: {
int nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
<< ", \"yes\": " << nyes
<< ", \"no\": " << tree[nid].DefaultChild();
} else {
fo << nid << ":[" << fmap.Name(split_index) << "] yes=" << nyes
<< ",no=" << tree[nid].DefaultChild();
}
result = this->Indicator(tree, nid, depth);
break;
}
case FeatureMap::kInteger: {
const bst_float floored = std::floor(cond);
const int integer_threshold
= (floored == cond) ? static_cast<int>(floored)
: static_cast<int>(floored) + 1;
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
<< ", \"split_condition\": " << integer_threshold
<< ", \"yes\": " << tree[nid].LeftChild()
<< ", \"no\": " << tree[nid].RightChild()
<< ", \"missing\": " << tree[nid].DefaultChild();
} else {
fo << nid << ":[" << fmap.Name(split_index) << "<"
<< integer_threshold
<< "] yes=" << tree[nid].LeftChild()
<< ",no=" << tree[nid].RightChild()
<< ",missing=" << tree[nid].DefaultChild();
}
result = this->Integer(tree, nid, depth);
break;
}
case FeatureMap::kFloat:
case FeatureMap::kQuantitive: {
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
<< ", \"split_condition\": " << std::setprecision(float_max_precision) << cond
<< ", \"yes\": " << tree[nid].LeftChild()
<< ", \"no\": " << tree[nid].RightChild()
<< ", \"missing\": " << tree[nid].DefaultChild();
} else {
fo << nid << ":[" << fmap.Name(split_index)
<< "<" << std::setprecision(float_max_precision) << cond
<< "] yes=" << tree[nid].LeftChild()
<< ",no=" << tree[nid].RightChild()
<< ",missing=" << tree[nid].DefaultChild();
}
result = this->Quantitive(tree, nid, depth);
break;
}
default: LOG(FATAL) << "unknown fmap type";
}
default:
LOG(FATAL) << "Unknown feature map type.";
}
} else {
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": " << split_index
<< ", \"split_condition\": " << std::setprecision(float_max_precision) << cond
<< ", \"yes\": " << tree[nid].LeftChild()
<< ", \"no\": " << tree[nid].RightChild()
<< ", \"missing\": " << tree[nid].DefaultChild();
} else {
fo << nid << ":[f" << split_index << "<"<< std::setprecision(float_max_precision) << cond
<< "] yes=" << tree[nid].LeftChild()
<< ",no=" << tree[nid].RightChild()
<< ",missing=" << tree[nid].DefaultChild();
result = this->PlainNode(tree, nid, depth);
}
return result;
}
virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
public:
TreeGenerator(FeatureMap const& _fmap, bool with_stats) :
fmap_{_fmap}, with_stats_{with_stats} {}
virtual ~TreeGenerator() = default;
virtual void BuildTree(RegTree const& tree) {
ss_ << this->BuildTree(tree, 0, 0);
}
std::string Str() const {
return ss_.str();
}
static TreeGenerator* Create(std::string const& attrs, FeatureMap const& fmap,
bool with_stats);
};
/*!
 * \brief Registry entry for tree dump generators.  Maps a format name
 *  (e.g. "text", "json", "dot") to a factory function that builds the
 *  matching TreeGenerator from a feature map, a parameter string and the
 *  with-stats flag.
 */
struct TreeGenReg : public dmlc::FunctionRegEntryBase<
TreeGenReg,
std::function<TreeGenerator* (
FeatureMap const& fmap, std::string attrs, bool with_stats)> > {
};
} // namespace xgboost
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::TreeGenReg);
} // namespace dmlc
namespace xgboost {
/*!
 * \brief Factory for tree dump generators.
 *
 * \param attrs      Format specifier, either "name" or "name:params" where
 *                   `name` is a registered generator ("text", "json", "dot")
 *                   and `params` is a parameter string forwarded verbatim to
 *                   the generator's constructor.
 * \param fmap       Feature map used to translate split indices into names.
 * \param with_stats Whether node statistics (gain/cover) are included.
 * \return Newly allocated generator; the caller takes ownership.
 */
TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const& fmap,
bool with_stats) {
auto pos = attrs.find(':');
std::string name;
std::string params;
if (pos != std::string::npos) {
name = attrs.substr(0, pos);
params = attrs.substr(pos+1, attrs.length() - pos - 1);
// Replace every single quote with a double quote: the parameter string may
// come from a Python dict repr, while the JSON reader used by generators
// requires double-quoted strings.  (Renamed from `pos` to avoid shadowing
// the outer `pos` above.)
size_t quote_pos = std::string::npos;
while ((quote_pos = params.find('\'')) != std::string::npos) {
params.replace(quote_pos, 1, "\"");
}
} else {
name = attrs;
}
auto *e = ::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown Model Builder:" << name;
}
auto p_io_builder = (e->body)(fmap, params, with_stats);
return p_io_builder;
}
#define XGBOOST_REGISTER_TREE_IO(UniqueId, Name) \
static DMLC_ATTRIBUTE_UNUSED ::xgboost::TreeGenReg& \
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)
/*!
 * \brief Generator for the plain-text dump format ("text").  Each node is
 *  rendered via string templates whose "{placeholder}" tokens are filled in
 *  by TreeGenerator::Match; nodes are indented with one tab per depth level.
 */
class TextGenerator : public TreeGenerator {
using SuperT = TreeGenerator;
public:
TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {}
// Render a leaf as "{nid}:leaf=<value>", with ",cover=<hess>" appended when
// statistics are requested.
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}";
static std::string kStatTemplate = ",cover={cover}";
std::string result = SuperT::Match(
kLeafTemplate,
{{"{tabs}", SuperT::Tabs(depth)},
{"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
{"{stats}", with_stats_ ?
SuperT::Match(kStatTemplate,
{{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result;
}
// Indicator (boolean) split: no "<cond>" threshold; the "yes" branch is the
// child opposite to the default direction.
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}";
int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
auto split_index = tree[nid].SplitIndex();
std::string result = SuperT::Match(
kIndicatorTemplate,
{{"{nid}", std::to_string(nid)},
{"{fname}", fmap_.Name(split_index)},
{"{yes}", std::to_string(nyes)},
{"{no}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
// Shared renderer for threshold-based splits; `cond` is the already
// stringified split condition.  Falls back to the numeric feature index when
// the feature map has no name for it.
std::string SplitNodeImpl(
RegTree const& tree, int32_t nid, std::string const& template_str,
std::string cond, uint32_t depth) {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
{{"{tabs}", SuperT::Tabs(depth)},
{"{nid}", std::to_string(nid)},
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
std::to_string(split_index)},
{"{cond}", cond},
{"{left}", std::to_string(tree[nid].LeftChild())},
{"{right}", std::to_string(tree[nid].RightChild())},
{"{missing}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
// Integer-typed split: threshold is the ceiling of the float condition, so
// the printed integer test is equivalent to the float comparison.
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kIntegerTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond();
const bst_float floored = std::floor(cond);
const int32_t integer_threshold
= (floored == cond) ? static_cast<int>(floored)
: static_cast<int>(floored) + 1;
return SplitNodeImpl(tree, nid, kIntegerTemplate,
std::to_string(integer_threshold), depth);
}
// Quantitative/float split: condition printed with full float precision.
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kQuantitiveTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
}
// Fallback when no feature-map type is available; feature shown as "f<idx>".
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
"{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
}
// Per-node statistics suffix: ",gain=...,cover=...".
std::string NodeStat(RegTree const& tree, int32_t nid) override {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match(
kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
return result;
}
// Recursively render the subtree rooted at `nid`, one node per line,
// pre-order (parent, then left, then right).
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
if (tree[nid].IsLeaf()) {
return this->LeafNode(tree, nid, depth);
}
static std::string const kNodeTemplate = "{parent}{stat}\n{left}\n{right}";
auto result = SuperT::Match(
kNodeTemplate,
{{"{parent}", this->SplitNode(tree, nid, depth)},
{"{stat}", with_stats_ ? this->NodeStat(tree, nid) : ""},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
return result;
}
// Entry point: render the whole tree from the root into the output stream.
void BuildTree(RegTree const& tree) override {
static std::string const& kTreeTemplate = "{nodes}\n";
auto result = SuperT::Match(
kTreeTemplate,
{{"{nodes}", this->BuildTree(tree, 0, 0)}});
ss_ << result;
}
};
XGBOOST_REGISTER_TREE_IO(TextGenerator, "text")
.describe("Dump text representation of tree")
.set_body([](FeatureMap const& fmap, std::string const& attrs, bool with_stats) {
return new TextGenerator(fmap, attrs, with_stats);
});
/*!
 * \brief Generator for the JSON dump format ("json").  Produces one JSON
 *  object per node; split nodes nest their children in a "children" array.
 *  Raw string literals (R"...") hold the templates filled by Match.
 */
class JsonGenerator : public TreeGenerator {
using SuperT = TreeGenerator;
public:
JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {}
// Two spaces of indentation per depth level (plus one base level).
std::string Indent(uint32_t depth) {
std::string result;
for (uint32_t i = 0; i < depth + 1; ++i) {
result += "  ";
}
return result;
}
// Leaf object: { "nodeid": ..., "leaf": ... } with optional "cover" stat.
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kLeafTemplate =
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
static std::string const kStatTemplate =
R"S(, "cover": {sum_hess} )S";
std::string result = SuperT::Match(
kLeafTemplate,
{{"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
{"{stat}", with_stats_ ? SuperT::Match(
kStatTemplate,
{{"{sum_hess}",
SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result;
}
// Indicator (boolean) split: no "split_condition"; "yes" is the child
// opposite to the default direction.
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
static std::string const kIndicatorTemplate =
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no}})ID";
auto split_index = tree[nid].SplitIndex();
auto result = SuperT::Match(
kIndicatorTemplate,
{{"{nid}", std::to_string(nid)},
{"{depth}", std::to_string(depth)},
{"{fname}", fmap_.Name(split_index)},
{"{yes}", std::to_string(nyes)},
{"{no}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
// Shared renderer for threshold-based splits; `cond` is pre-stringified.
// Falls back to the numeric feature index when the feature map is empty.
std::string SplitNodeImpl(RegTree const& tree, int32_t nid,
std::string const& template_str, std::string cond, uint32_t depth) {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
{{"{nid}", std::to_string(nid)},
{"{depth}", std::to_string(depth)},
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
std::to_string(split_index)},
{"{cond}", cond},
{"{left}", std::to_string(tree[nid].LeftChild())},
{"{right}", std::to_string(tree[nid].RightChild())},
{"{missing}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
// Integer split: threshold is the ceiling of the float condition, keeping
// the integer test equivalent to the float comparison.
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
auto cond = tree[nid].SplitCond();
const bst_float floored = std::floor(cond);
const int32_t integer_threshold
= (floored == cond) ? static_cast<int32_t>(floored)
: static_cast<int32_t>(floored) + 1;
static std::string const kIntegerTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
return SplitNodeImpl(tree, nid, kIntegerTemplate,
std::to_string(integer_threshold), depth);
}
// Quantitative/float split: condition printed with full float precision.
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kQuantitiveTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
bst_float cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
}
// Fallback when no feature-map type is available; note "split" is unquoted
// here (numeric index) unlike the named variants above.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
}
// Optional statistics members: , "gain": ..., "cover": ...
std::string NodeStat(RegTree const& tree, int32_t nid) override {
static std::string kStatTemplate =
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
auto result = SuperT::Match(
kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
return result;
}
// Split object: wraps the base-class properties, then recurses into both
// children inside a "children" array.  The literal "{{properties}" yields
// the object's opening brace followed by the {properties} placeholder.
std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string properties = SuperT::SplitNode(tree, nid, depth);
static std::string const kSplitNodeTemplate =
"{{properties} {stat}, \"children\": [{left}, {right}\n{indent}]}";
auto result = SuperT::Match(
kSplitNodeTemplate,
{{"{stat}", with_stats_ ? this->NodeStat(tree, nid) : ""},
{"{properties}", properties},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)},
{"{indent}", this->Indent(depth)}});
return result;
}
// Recursively render the subtree at `nid`; each node starts on a fresh,
// indented line except the root.
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kNodeTemplate = "{newline}{indent}{nodes}";
auto result = SuperT::Match(
kNodeTemplate,
{{"{newline}", depth == 0 ? "" : "\n"},
{"{indent}", Indent(depth)},
{"{nodes}", tree[nid].IsLeaf() ? this->LeafNode(tree, nid, depth) :
this->SplitNode(tree, nid, depth)}});
return result;
}
};
XGBOOST_REGISTER_TREE_IO(JsonGenerator, "json")
.describe("Dump json representation of tree")
.set_body([](FeatureMap const& fmap, std::string const& attrs, bool with_stats) {
return new JsonGenerator(fmap, attrs, with_stats);
});
/*!
 * \brief User-tunable parameters for the graphviz ("dot") dump format,
 *  parsed from the parameter string passed through `Booster.get_dump`.
 *  All fields are plain strings spliced into the generated DOT source.
 */
struct GraphvizParam : public dmlc::Parameter<GraphvizParam> {
// Edge color for the branch taken when the condition holds.
std::string yes_color;
// Edge color for the branch taken when the condition fails.
std::string no_color;
// Graph layout direction, e.g. "TB" (top-bottom) or "LR" (left-right).
std::string rankdir;
// Extra `key="value"` attributes appended to condition (split) nodes.
std::string condition_node_params;
// Extra `key="value"` attributes appended to leaf nodes.
std::string leaf_node_params;
// Pre-rendered `graph [ key="value" ]` lines for additional graph attributes.
std::string graph_attrs;
DMLC_DECLARE_PARAMETER(GraphvizParam){
DMLC_DECLARE_FIELD(yes_color)
.set_default("#0000FF")
.describe("Edge color when meets the node condition.");
DMLC_DECLARE_FIELD(no_color)
.set_default("#FF0000")
.describe("Edge color when doesn't meet the node condition.");
DMLC_DECLARE_FIELD(rankdir)
.set_default("TB")
// Fixed typo in help text: "graphiz" -> "graphviz".
.describe("Passed to graphviz via graph_attr.");
DMLC_DECLARE_FIELD(condition_node_params)
.set_default("")
.describe("Conditional node configuration");
DMLC_DECLARE_FIELD(leaf_node_params)
.set_default("")
.describe("Leaf node configuration");
DMLC_DECLARE_FIELD(graph_attrs)
.set_default("")
.describe("Any other extra attributes for graphviz `graph_attr`.");
}
};
DMLC_REGISTER_PARAMETER(GraphvizParam);
class GraphvizGenerator : public TreeGenerator {
using SuperT = TreeGenerator;
std::stringstream& ss_;
GraphvizParam param_;
public:
GraphvizGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
TreeGenerator(fmap, with_stats), ss_{SuperT::ss_} {
param_.InitAllowUnknown(std::map<std::string, std::string>{});
using KwArg = std::map<std::string, std::map<std::string, std::string>>;
KwArg kwargs;
if (attrs.length() != 0) {
std::istringstream iss(attrs);
try {
dmlc::JSONReader reader(&iss);
reader.Read(&kwargs);
} catch(dmlc::Error const& e) {
LOG(FATAL) << "Failed to parse graphviz parameters:\n\t"
<< attrs << "\n"
<< "With error:\n"
<< e.what();
}
}
if (with_stats) {
if (format == "json") {
fo << ", \"gain\": " << std::setprecision(float_max_precision) << tree.Stat(nid).loss_chg
<< ", \"cover\": " << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
} else {
fo << ",gain=" << std::setprecision(float_max_precision) << tree.Stat(nid).loss_chg
<< ",cover=" << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
// This turns out to be tricky, as `dmlc::Parameter::Load(JSONReader*)` doesn't
// support loading nested json objects.
if (kwargs.find("condition_node_params") != kwargs.cend()) {
auto const& cnp = kwargs["condition_node_params"];
for (auto const& kv : cnp) {
param_.condition_node_params += kv.first + '=' + "\"" + kv.second + "\" ";
}
kwargs.erase("condition_node_params");
}
if (format == "json") {
fo << ", \"children\": [";
} else {
fo << '\n';
}
DumpRegTree(fo, tree, fmap, tree[nid].LeftChild(), depth + 1, false, with_stats, format);
DumpRegTree(fo, tree, fmap, tree[nid].RightChild(), depth + 1, true, with_stats, format);
if (format == "json") {
fo << std::endl;
for (int i = 0; i < depth + 1; ++i) {
fo << " ";
if (kwargs.find("leaf_node_params") != kwargs.cend()) {
auto const& lnp = kwargs["leaf_node_params"];
for (auto const& kv : lnp) {
param_.leaf_node_params += kv.first + '=' + "\"" + kv.second + "\" ";
}
fo << "]}";
kwargs.erase("leaf_node_params");
}
if (kwargs.find("edge") != kwargs.cend()) {
if (kwargs["edge"].find("yes_color") != kwargs["edge"].cend()) {
param_.yes_color = kwargs["edge"]["yes_color"];
}
if (kwargs["edge"].find("no_color") != kwargs["edge"].cend()) {
param_.no_color = kwargs["edge"]["no_color"];
}
kwargs.erase("edge");
}
auto const& extra = kwargs["graph_attrs"];
static std::string const kGraphTemplate = " graph [ {key}=\"{value}\" ]\n";
for (auto const& kv : extra) {
param_.graph_attrs += SuperT::Match(kGraphTemplate,
{{"{key}", kv.first},
{"{value}", kv.second}});
}
kwargs.erase("graph_attrs");
if (kwargs.size() != 0) {
std::stringstream ss;
ss << "The following parameters for graphviz are not recognized:\n";
for (auto kv : kwargs) {
ss << kv.first << ", ";
}
LOG(WARNING) << ss.str();
}
}
}
protected:
// Only indicator is different, so we combine all different node types into this
// function.
// Render one split node plus its two outgoing edges in DOT syntax.  All
// non-leaf node kinds share this routine; only indicator features differ
// (they show the feature name without a "<condition" suffix).
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
  auto const split_index = tree[nid].SplitIndex();
  auto const split_cond = tree[nid].SplitCond();
  static std::string const kNodeTemplate =
      " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
  // Indicator features carry no threshold in the label; everything else
  // (and any feature missing from the feature map) shows "<cond".
  bool const show_cond = (split_index >= fmap_.Size()) ||
                         fmap_.type(split_index) != FeatureMap::kIndicator;
  std::string result = SuperT::Match(kNodeTemplate, {
      {"{nid}", std::to_string(nid)},
      {"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
                                               'f' + std::to_string(split_index)},
      {"{<}", show_cond ? "<" : ""},
      {"{cond}", show_cond ? SuperT::ToStr(split_cond) : ""},
      {"{params}", param_.condition_node_params}});

  static std::string const kEdgeTemplate =
      " {nid} -> {child} [label=\"{is_missing}\" color=\"{color}\"]\n";
  // Emit the left edge first, then the right one; the default (missing
  // value) branch is labelled and coloured differently.
  for (auto child : {tree[nid].LeftChild(), tree[nid].RightChild()}) {
    bool const is_missing = tree[nid].DefaultChild() == child;
    result += SuperT::Match(kEdgeTemplate, {
        {"{nid}", std::to_string(nid)},
        {"{child}", std::to_string(child)},
        {"{color}", is_missing ? param_.yes_color : param_.no_color},
        {"{is_missing}", is_missing ? "yes, missing" : "no"}});
  }
  return result;
}
// Render one leaf node in DOT syntax, showing its output value and any
// user-supplied leaf node parameters.
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
  static std::string const kLeafTemplate =
      " {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
  return SuperT::Match(kLeafTemplate,
                       {{"{nid}", std::to_string(nid)},
                        {"{leaf-value}", ToStr(tree[nid].LeafValue())},
                        {"{params}", param_.leaf_node_params}});
}
// Recursively render the subtree rooted at `nid`: a leaf terminates the
// recursion, an internal node emits itself followed by both children.
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
  if (tree[nid].IsLeaf()) {
    return this->LeafNode(tree, nid, depth);
  }
  static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
  auto const child_depth = depth + 1;
  return SuperT::Match(
      kNodeTemplate,
      {{"{parent}", this->PlainNode(tree, nid, depth)},
       {"{left}", this->BuildTree(tree, tree[nid].LeftChild(), child_depth)},
       {"{right}", this->BuildTree(tree, tree[nid].RightChild(), child_depth)}});
}
// Wrap the rendered nodes in the surrounding `digraph { ... }` skeleton
// (rank direction plus any user graph attributes) and append it to the
// generator's output stream.
void BuildTree(RegTree const& tree) override {
  static std::string const kTreeTemplate =
      "digraph {\n"
      " graph [ rankdir={rankdir} ]\n"
      "{graph_attrs}\n"
      "{nodes}}";
  ss_ << SuperT::Match(kTreeTemplate,
                       {{"{rankdir}", param_.rankdir},
                        {"{graph_attrs}", param_.graph_attrs},
                        {"{nodes}", this->BuildTree(tree, 0, 0)}});
}
};
// Register the graphviz generator under the "dot" format name so that
// RegTree::DumpModel(..., "dot") — optionally with attached parameters,
// e.g. "dot:{...}" — dispatches to GraphvizGenerator.
XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
.describe("Dump graphviz representation of tree")
.set_body([](FeatureMap const& fmap, std::string attrs, bool with_stats) {
return new GraphvizGenerator(fmap, attrs, with_stats);
});
/*!
 * \brief Dump the model in the requested textual format.
 * \param fmap feature map that translates feature indices to names/types.
 * \param with_stats whether to include split statistics (gain, cover).
 * \param format one of the registered generators, e.g. "text", "json" or
 *        "dot"; may carry parameters after a colon, e.g. "dot:{...}".
 * \return the dumped representation of this tree.
 *
 * NOTE: the displayed text fused the pre-refactor body (a DumpRegTree loop
 * plus `return fo.str()`) with the new generator-based one; only the
 * generator path is kept here.
 */
std::string RegTree::DumpModel(const FeatureMap& fmap,
                               bool with_stats,
                               std::string format) const {
  // The generator is looked up by format name in the TreeGenerator registry.
  std::unique_ptr<TreeGenerator> builder {
    TreeGenerator::Create(format, fmap, with_stats)
  };
  for (int32_t i = 0; i < param.num_roots; ++i) {
    builder->BuildTree(*this);
  }
  return builder->Str();
}
void RegTree::FillNodeMeanValues() {
size_t num_nodes = this->param.num_nodes;
if (this->node_mean_values_.size() == num_nodes) {

View File

@ -144,7 +144,7 @@ __global__ void CubScanByKeyL1(
int previousKey = __shfl_up_sync(0xFFFFFFFF, myKey, 1);
#else
int previousKey = __shfl_up(myKey, 1);
#endif
#endif // (__CUDACC_VER_MAJOR__ >= 9)
// Collectively compute the block-wide exclusive prefix sum
BlockScan(temp_storage)
.ExclusiveScan(threadData, threadData, rootPair, AddByKey());

View File

@ -101,4 +101,121 @@ TEST(Tree, AllocateNode) {
ASSERT_TRUE(nodes.at(1).IsLeaf());
ASSERT_TRUE(nodes.at(2).IsLeaf());
}
// Build the small fixture tree shared by the dump tests: the root splits on
// feature 0 and each of its children splits on features 1 and 2, giving
// three internal nodes and four leaves.
RegTree ConstructTree() {
  RegTree tree;
  tree.ExpandNode(/*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
                  /*default_left=*/true,
                  0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
  auto const left_child = tree[0].LeftChild();
  auto const right_child = tree[0].RightChild();
  tree.ExpandNode(/*nid=*/left_child, /*split_index=*/1, /*split_value=*/1.0f,
                  /*default_left=*/false,
                  0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
  tree.ExpandNode(/*nid=*/right_child, /*split_index=*/2, /*split_value=*/2.0f,
                  /*default_left=*/false,
                  0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
  return tree;
}
TEST(Tree, DumpJson) {
  auto tree = ConstructTree();
  FeatureMap fmap;
  // Count occurrences of `sub` in `str`.  The original loops seeded the
  // search with `find(sub, iter + 1)` at iter == 0, which would miss a
  // match at position 0; scanning from the previous hit + 1 is correct.
  auto CountOccurrences = [](std::string const& str, std::string const& sub) {
    size_t count = 0;
    for (size_t pos = str.find(sub); pos != std::string::npos;
         pos = str.find(sub, pos + 1)) {
      ++count;
    }
    return count;
  };

  auto str = tree.DumpModel(fmap, true, "json");
  // Three splits produce four leaves.
  ASSERT_EQ(CountOccurrences(str, "leaf"), 4);
  // One split_condition per internal node.
  ASSERT_EQ(CountOccurrences(str, "split_condition"), 3);

  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");
  str = tree.DumpModel(fmap, true, "json");
  ASSERT_NE(str.find(R"("split": "feat_0")"), std::string::npos);
  ASSERT_NE(str.find(R"("split": "feat_1")"), std::string::npos);
  ASSERT_NE(str.find(R"("split": "feat_2")"), std::string::npos);

  // Statistics such as cover are emitted only when with_stats is true.
  str = tree.DumpModel(fmap, false, "json");
  ASSERT_EQ(str.find("cover"), std::string::npos);
}
TEST(Tree, DumpText) {
  auto tree = ConstructTree();
  FeatureMap fmap;
  // Count occurrences of `sub` in `str`.  The original loops seeded the
  // search with `find(sub, iter + 1)` at iter == 0, which would miss a
  // match at position 0; scanning from the previous hit + 1 is correct.
  auto CountOccurrences = [](std::string const& str, std::string const& sub) {
    size_t count = 0;
    for (size_t pos = str.find(sub); pos != std::string::npos;
         pos = str.find(sub, pos + 1)) {
      ++count;
    }
    return count;
  };

  auto str = tree.DumpModel(fmap, true, "text");
  // Three splits produce four leaves.
  ASSERT_EQ(CountOccurrences(str, "leaf"), 4);
  // With stats, each internal node reports a gain.
  ASSERT_EQ(CountOccurrences(str, "gain"), 3);

  // Without a feature map, features fall back to their f<index> names.
  ASSERT_NE(str.find("[f0<0]"), std::string::npos);
  ASSERT_NE(str.find("[f1<1]"), std::string::npos);
  ASSERT_NE(str.find("[f2<2]"), std::string::npos);

  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");
  str = tree.DumpModel(fmap, true, "text");
  // Indicator features ("i") omit the split condition in text dumps.
  ASSERT_NE(str.find("[feat_0]"), std::string::npos);
  ASSERT_NE(str.find("[feat_1<1]"), std::string::npos);
  ASSERT_NE(str.find("[feat_2<2]"), std::string::npos);

  // Statistics such as cover are emitted only when with_stats is true.
  str = tree.DumpModel(fmap, false, "text");
  ASSERT_EQ(str.find("cover"), std::string::npos);
}
TEST(Tree, DumpDot) {
  auto tree = ConstructTree();
  FeatureMap fmap;
  // Count occurrences of `sub` in `str`.  The original loops seeded the
  // search with `find(sub, iter + 1)` at iter == 0, which would miss a
  // match at position 0; scanning from the previous hit + 1 is correct.
  auto CountOccurrences = [](std::string const& str, std::string const& sub) {
    size_t count = 0;
    for (size_t pos = str.find(sub); pos != std::string::npos;
         pos = str.find(sub, pos + 1)) {
      ++count;
    }
    return count;
  };

  auto str = tree.DumpModel(fmap, true, "dot");
  // Three splits produce four leaves.
  ASSERT_EQ(CountOccurrences(str, "leaf"), 4);
  // Each of the three internal nodes emits two DOT edges.
  ASSERT_EQ(CountOccurrences(str, "->"), 6);

  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");
  str = tree.DumpModel(fmap, true, "dot");
  // Indicator features show the name only; others keep the "<cond" suffix.
  ASSERT_NE(str.find(R"("feat_0")"), std::string::npos);
  ASSERT_NE(str.find(R"(feat_1<1)"), std::string::npos);
  ASSERT_NE(str.find(R"(feat_2<2)"), std::string::npos);

  // Parameters appended after "dot:" are forwarded to the generator.
  str = tree.DumpModel(fmap, true,
                       R"(dot:{"graph_attrs": {"bgcolor": "#FFFF00"}})");
  ASSERT_NE(str.find(R"(graph [ bgcolor="#FFFF00" ])"), std::string::npos);
}
} // namespace xgboost

View File

@ -10,7 +10,7 @@ try:
import matplotlib
matplotlib.use('Agg')
from matplotlib.axes import Axes
from graphviz import Digraph
from graphviz import Source
except ImportError:
pass
@ -57,7 +57,7 @@ class TestPlotting(unittest.TestCase):
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
g = xgb.to_graphviz(bst2, num_trees=0)
assert isinstance(g, Digraph)
assert isinstance(g, Source)
ax = xgb.plot_tree(bst2, num_trees=0)
assert isinstance(ax, Axes)

View File

@ -87,7 +87,6 @@ class TestSHAP(unittest.TestCase):
r_exp = r"([0-9]+):\[f([0-9]+)<([0-9\.e-]+)\] yes=([0-9]+),no=([0-9]+).*cover=([0-9e\.]+)"
r_exp_leaf = r"([0-9]+):leaf=([0-9\.e-]+),cover=([0-9e\.]+)"
for tree in model.get_dump(with_stats=True):
lines = list(tree.splitlines())
trees.append([None for i in range(len(lines))])
for line in lines:

View File

@ -352,7 +352,7 @@ def test_sklearn_plotting():
matplotlib.use('Agg')
from matplotlib.axes import Axes
from graphviz import Digraph
from graphviz import Source
ax = xgb.plot_importance(classifier)
assert isinstance(ax, Axes)
@ -362,7 +362,7 @@ def test_sklearn_plotting():
assert len(ax.patches) == 4
g = xgb.to_graphviz(classifier, num_trees=0)
assert isinstance(g, Digraph)
assert isinstance(g, Source)
ax = xgb.plot_tree(classifier, num_trees=0)
assert isinstance(ax, Axes)
@ -641,7 +641,8 @@ def test_XGBClassifier_resume():
X, Y = load_breast_cancer(return_X_y=True)
model1 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
model1 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8)
model1.fit(X, Y)
pred1 = model1.predict(X)
@ -649,7 +650,8 @@ def test_XGBClassifier_resume():
# file name of stored xgb model
model1.save_model(model1_path)
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
model2 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8)
model2.fit(X, Y, xgb_model=model1_path)
pred2 = model2.predict(X)
@ -660,7 +662,8 @@ def test_XGBClassifier_resume():
# file name of 'Booster' instance Xgb model
model1.get_booster().save_model(model1_booster_path)
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
model2 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8)
model2.fit(X, Y, xgb_model=model1_booster_path)
pred2 = model2.predict(X)