Implement tree model dump with code generator. (#4602)

* Implement tree model dump with a code generator.

* Split up generators.
* Implement graphviz generator.
* Use pattern matching.

* [Breaking] Return a Source in `to_graphviz` instead of Digraph in Python package.


Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2019-06-26 15:20:44 +08:00 committed by GitHub
parent fe2de6f415
commit 8bdf15120a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 802 additions and 264 deletions

View File

@ -7,6 +7,8 @@
#ifndef XGBOOST_FEATURE_MAP_H_ #ifndef XGBOOST_FEATURE_MAP_H_
#define XGBOOST_FEATURE_MAP_H_ #define XGBOOST_FEATURE_MAP_H_
#include <xgboost/logging.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include <cstring> #include <cstring>

View File

@ -1419,7 +1419,7 @@ class Booster(object):
with_stats : bool, optional with_stats : bool, optional
Controls whether the split statistics are output. Controls whether the split statistics are output.
dump_format : string, optional dump_format : string, optional
Format of model dump. Can be 'text' or 'json'. Format of model dump. Can be 'text', 'json' or 'dot'.
""" """
length = c_bst_ulong() length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)() sarr = ctypes.POINTER(ctypes.c_char_p)()

View File

@ -1,10 +1,7 @@
# coding: utf-8
# pylint: disable=too-many-locals, too-many-arguments, invalid-name, # pylint: disable=too-many-locals, too-many-arguments, invalid-name,
# pylint: disable=too-many-branches # pylint: disable=too-many-branches
# coding: utf-8
"""Plotting Library.""" """Plotting Library."""
from __future__ import absolute_import
import re
from io import BytesIO from io import BytesIO
import numpy as np import numpy as np
from .core import Booster from .core import Booster
@ -61,7 +58,8 @@ def plot_importance(booster, ax=None, height=0.2,
raise ImportError('You must install matplotlib to plot importance') raise ImportError('You must install matplotlib to plot importance')
if isinstance(booster, XGBModel): if isinstance(booster, XGBModel):
importance = booster.get_booster().get_score(importance_type=importance_type) importance = booster.get_booster().get_score(
importance_type=importance_type)
elif isinstance(booster, Booster): elif isinstance(booster, Booster):
importance = booster.get_score(importance_type=importance_type) importance = booster.get_score(importance_type=importance_type)
elif isinstance(booster, dict): elif isinstance(booster, dict):
@ -117,56 +115,11 @@ def plot_importance(booster, ax=None, height=0.2,
return ax return ax
_NODEPAT = re.compile(r'(\d+):\[(.+)\]') def to_graphviz(booster, fmap='', num_trees=0, rankdir=None,
_LEAFPAT = re.compile(r'(\d+):(leaf=.+)') yes_color=None, no_color=None,
_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')
def _parse_node(graph, text, condition_node_params, leaf_node_params):
"""parse dumped node"""
match = _NODEPAT.match(text)
if match is not None:
node = match.group(1)
graph.node(node, label=match.group(2), **condition_node_params)
return node
match = _LEAFPAT.match(text)
if match is not None:
node = match.group(1)
graph.node(node, label=match.group(2), **leaf_node_params)
return node
raise ValueError('Unable to parse node: {0}'.format(text))
def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
"""parse dumped edge"""
try:
match = _EDGEPAT.match(text)
if match is not None:
yes, no, missing = match.groups()
if yes == missing:
graph.edge(node, yes, label='yes, missing', color=yes_color)
graph.edge(node, no, label='no', color=no_color)
else:
graph.edge(node, yes, label='yes', color=yes_color)
graph.edge(node, no, label='no, missing', color=no_color)
return
except ValueError:
pass
match = _EDGEPAT2.match(text)
if match is not None:
yes, no = match.groups()
graph.edge(node, yes, label='yes', color=yes_color)
graph.edge(node, no, label='no', color=no_color)
return
raise ValueError('Unable to parse edge: {0}'.format(text))
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
yes_color='#0000FF', no_color='#FF0000',
condition_node_params=None, leaf_node_params=None, **kwargs): condition_node_params=None, leaf_node_params=None, **kwargs):
"""Convert specified tree to graphviz instance. IPython can automatically plot the """Convert specified tree to graphviz instance. IPython can automatically plot
returned graphiz instance. Otherwise, you should call .render() method the returned graphiz instance. Otherwise, you should call .render() method
of the returned graphiz instance. of the returned graphiz instance.
Parameters Parameters
@ -184,64 +137,77 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
no_color : str, default '#FF0000' no_color : str, default '#FF0000'
Edge color when doesn't meet the node condition. Edge color when doesn't meet the node condition.
condition_node_params : dict (optional) condition_node_params : dict (optional)
condition node configuration, Condition node configuration for for graphviz. Example:
{'shape':'box',
'style':'filled,rounded', .. code-block:: python
'fillcolor':'#78bceb'}
{'shape': 'box',
'style': 'filled,rounded',
'fillcolor': '#78bceb'}
leaf_node_params : dict (optional) leaf_node_params : dict (optional)
leaf node configuration Leaf node configuration for graphviz. Example:
{'shape':'box',
'style':'filled',
'fillcolor':'#e48038'}
kwargs : .. code-block:: python
Other keywords passed to graphviz graph_attr
{'shape': 'box',
'style': 'filled',
'fillcolor': '#e48038'}
kwargs : Other keywords passed to graphviz graph_attr, E.g.:
``graph [ {key} = {value} ]``
Returns Returns
------- -------
ax : matplotlib Axes graph: graphviz.Source
""" """
if condition_node_params is None:
condition_node_params = {}
if leaf_node_params is None:
leaf_node_params = {}
try: try:
from graphviz import Digraph from graphviz import Source
except ImportError: except ImportError:
raise ImportError('You must install graphviz to plot tree') raise ImportError('You must install graphviz to plot tree')
if not isinstance(booster, (Booster, XGBModel)):
raise ValueError('booster must be Booster or XGBModel instance')
if isinstance(booster, XGBModel): if isinstance(booster, XGBModel):
booster = booster.get_booster() booster = booster.get_booster()
tree = booster.get_dump(fmap=fmap)[num_trees] # squash everything back into kwargs again for compatibility
tree = tree.split() parameters = 'dot'
extra = {}
for key, value in kwargs.items():
extra[key] = value
kwargs = kwargs.copy() if rankdir is not None:
kwargs.update({'rankdir': rankdir}) kwargs['graph_attrs'] = {}
graph = Digraph(graph_attr=kwargs) kwargs['graph_attrs']['rankdir'] = rankdir
for key, value in extra.items():
for i, text in enumerate(tree): if 'graph_attrs' in kwargs.keys():
if text[0].isdigit(): kwargs['graph_attrs'][key] = value
node = _parse_node(
graph, text, condition_node_params=condition_node_params,
leaf_node_params=leaf_node_params)
else: else:
if i == 0: kwargs['graph_attrs'] = {}
# 1st string must be node del kwargs[key]
raise ValueError('Unable to parse given string as tree')
_parse_edge(graph, node, text, yes_color=yes_color,
no_color=no_color)
return graph if yes_color is not None or no_color is not None:
kwargs['edge'] = {}
if yes_color is not None:
kwargs['edge']['yes_color'] = yes_color
if no_color is not None:
kwargs['edge']['no_color'] = no_color
if condition_node_params is not None:
kwargs['condition_node_params'] = condition_node_params
if leaf_node_params is not None:
kwargs['leaf_node_params'] = leaf_node_params
if kwargs:
parameters += ':'
parameters += str(kwargs)
tree = booster.get_dump(
fmap=fmap,
dump_format=parameters)[num_trees]
g = Source(tree)
return g
def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs): def plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs):
"""Plot specified tree. """Plot specified tree.
Parameters Parameters
@ -252,7 +218,7 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
The name of feature map file The name of feature map file
num_trees : int, default 0 num_trees : int, default 0
Specify the ordinal number of target tree Specify the ordinal number of target tree
rankdir : str, default "UT" rankdir : str, default "TB"
Passed to graphiz via graph_attr Passed to graphiz via graph_attr
ax : matplotlib Axes, default None ax : matplotlib Axes, default None
Target axes instance. If None, new figure and axes will be created. Target axes instance. If None, new figure and axes will be created.
@ -264,18 +230,17 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
ax : matplotlib Axes ax : matplotlib Axes
""" """
try: try:
import matplotlib.pyplot as plt from matplotlib import pyplot as plt
import matplotlib.image as image from matplotlib import image
except ImportError: except ImportError:
raise ImportError('You must install matplotlib to plot tree') raise ImportError('You must install matplotlib to plot tree')
if ax is None: if ax is None:
_, ax = plt.subplots(1, 1) _, ax = plt.subplots(1, 1)
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir,
rankdir=rankdir, **kwargs) **kwargs)
s = BytesIO() s = BytesIO()
s.write(g.pipe(format='png')) s.write(g.pipe(format='png'))

View File

@ -1033,19 +1033,21 @@ inline void XGBoostDumpModelImpl(
*out_models = dmlc::BeginPtr(charp_vecs); *out_models = dmlc::BeginPtr(charp_vecs);
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size()); *len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
} }
XGB_DLL int XGBoosterDumpModel(BoosterHandle handle, XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
const char* fmap, const char* fmap,
int with_stats, int with_stats,
xgboost::bst_ulong* len, xgboost::bst_ulong* len,
const char*** out_models) { const char*** out_models) {
return XGBoosterDumpModelEx(handle, fmap, with_stats, "text", len, out_models); return XGBoosterDumpModelEx(handle, fmap, with_stats, "text", len, out_models);
} }
XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle, XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
const char* fmap, const char* fmap,
int with_stats, int with_stats,
const char *format, const char *format,
xgboost::bst_ulong* len, xgboost::bst_ulong* len,
const char*** out_models) { const char*** out_models) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
FeatureMap featmap; FeatureMap featmap;
@ -1060,23 +1062,24 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
} }
XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle, XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
int fnum, int fnum,
const char** fname, const char** fname,
const char** ftype, const char** ftype,
int with_stats, int with_stats,
xgboost::bst_ulong* len, xgboost::bst_ulong* len,
const char*** out_models) { const char*** out_models) {
return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats, return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats,
"text", len, out_models); "text", len, out_models);
} }
XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle, XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
int fnum, int fnum,
const char** fname, const char** fname,
const char** ftype, const char** ftype,
int with_stats, int with_stats,
const char *format, const char *format,
xgboost::bst_ulong* len, xgboost::bst_ulong* len,
const char*** out_models) { const char*** out_models) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
FeatureMap featmap; FeatureMap featmap;

View File

@ -10,7 +10,7 @@
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) #if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
#include <nvToolsExt.h> #include <nvToolsExt.h>
#endif #endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
namespace xgboost { namespace xgboost {
namespace common { namespace common {
@ -98,7 +98,7 @@ struct Monitor {
stats.timer.Start(); stats.timer.Start();
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) #if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
stats.nvtx_id = nvtxRangeStartA(name.c_str()); stats.nvtx_id = nvtxRangeStartA(name.c_str());
#endif #endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
} }
} }
void StopCuda(const std::string &name) { void StopCuda(const std::string &name) {
@ -108,7 +108,7 @@ struct Monitor {
stats.count++; stats.count++;
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) #if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
nvtxRangeEnd(stats.nvtx_id); nvtxRangeEnd(stats.nvtx_id);
#endif #endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
} }
} }
}; };

View File

@ -1,14 +1,19 @@
/*! /*!
* Copyright 2015 by Contributors * Copyright 2015-2019 by Contributors
* \file tree_model.cc * \file tree_model.cc
* \brief model structure for tree * \brief model structure for tree
*/ */
#include <dmlc/registry.h>
#include <dmlc/json.h>
#include <xgboost/tree_model.h> #include <xgboost/tree_model.h>
#include <xgboost/logging.h>
#include <sstream> #include <sstream>
#include <limits> #include <limits>
#include <cmath> #include <cmath>
#include <iomanip> #include <iomanip>
#include "./param.h"
#include "param.h"
namespace xgboost { namespace xgboost {
// register tree parameter // register tree parameter
@ -17,158 +22,602 @@ DMLC_REGISTER_PARAMETER(TreeParam);
namespace tree { namespace tree {
DMLC_REGISTER_PARAMETER(TrainParam); DMLC_REGISTER_PARAMETER(TrainParam);
} }
// internal function to dump regression tree to text
void DumpRegTree(std::stringstream& fo, // NOLINT(*) /*!
const RegTree& tree, * \brief Base class for dump model implementation, modeling closely after code generator.
const FeatureMap& fmap, */
int nid, int depth, int add_comma, class TreeGenerator {
bool with_stats, std::string format) { protected:
int float_max_precision = std::numeric_limits<bst_float>::max_digits10; static int32_t constexpr kFloatMaxPrecision =
if (format == "json") { std::numeric_limits<bst_float>::max_digits10;
if (add_comma) { FeatureMap const& fmap_;
fo << ","; std::stringstream ss_;
} bool const with_stats_;
if (depth != 0) {
fo << std::endl; template <typename Float>
} static std::string ToStr(Float value) {
for (int i = 0; i < depth + 1; ++i) { static_assert(std::is_floating_point<Float>::value,
fo << " "; "Use std::to_string instead for non-floating point values.");
} std::stringstream ss;
} else { ss << std::setprecision(kFloatMaxPrecision) << value;
for (int i = 0; i < depth; ++i) { return ss.str();
fo << '\t';
}
} }
if (tree[nid].IsLeaf()) {
if (format == "json") { static std::string Tabs(uint32_t n) {
fo << "{ \"nodeid\": " << nid std::string res;
<< ", \"leaf\": " << std::setprecision(float_max_precision) << tree[nid].LeafValue(); for (uint32_t i = 0; i < n; ++i) {
if (with_stats) { res += '\t';
fo << ", \"cover\": " << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
}
fo << " }";
} else {
fo << nid << ":leaf=" << std::setprecision(float_max_precision) << tree[nid].LeafValue();
if (with_stats) {
fo << ",cover=" << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
}
fo << '\n';
} }
} else { return res;
// right then left, }
bst_float cond = tree[nid].SplitCond(); /* \brief Find the first occurance of key in input and replace it with corresponding
const unsigned split_index = tree[nid].SplitIndex(); * value.
if (split_index < fmap.Size()) { */
switch (fmap.type(split_index)) { static std::string Match(std::string const& input,
std::map<std::string, std::string> const& replacements) {
std::string result = input;
for (auto const& kv : replacements) {
auto pos = result.find(kv.first);
CHECK_NE(pos, std::string::npos);
result.replace(pos, kv.first.length(), kv.second);
}
return result;
}
virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) {
return "";
}
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) {
return "";
}
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) {
return "";
}
virtual std::string NodeStat(RegTree const& tree, int32_t nid) {
return "";
}
virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
auto const split_index = tree[nid].SplitIndex();
std::string result;
if (split_index < fmap_.Size()) {
switch (fmap_.type(split_index)) {
case FeatureMap::kIndicator: { case FeatureMap::kIndicator: {
int nyes = tree[nid].DefaultLeft() ? result = this->Indicator(tree, nid, depth);
tree[nid].RightChild() : tree[nid].LeftChild();
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
<< ", \"yes\": " << nyes
<< ", \"no\": " << tree[nid].DefaultChild();
} else {
fo << nid << ":[" << fmap.Name(split_index) << "] yes=" << nyes
<< ",no=" << tree[nid].DefaultChild();
}
break; break;
} }
case FeatureMap::kInteger: { case FeatureMap::kInteger: {
const bst_float floored = std::floor(cond); result = this->Integer(tree, nid, depth);
const int integer_threshold
= (floored == cond) ? static_cast<int>(floored)
: static_cast<int>(floored) + 1;
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
<< ", \"split_condition\": " << integer_threshold
<< ", \"yes\": " << tree[nid].LeftChild()
<< ", \"no\": " << tree[nid].RightChild()
<< ", \"missing\": " << tree[nid].DefaultChild();
} else {
fo << nid << ":[" << fmap.Name(split_index) << "<"
<< integer_threshold
<< "] yes=" << tree[nid].LeftChild()
<< ",no=" << tree[nid].RightChild()
<< ",missing=" << tree[nid].DefaultChild();
}
break; break;
} }
case FeatureMap::kFloat: case FeatureMap::kFloat:
case FeatureMap::kQuantitive: { case FeatureMap::kQuantitive: {
if (format == "json") { result = this->Quantitive(tree, nid, depth);
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
<< ", \"split_condition\": " << std::setprecision(float_max_precision) << cond
<< ", \"yes\": " << tree[nid].LeftChild()
<< ", \"no\": " << tree[nid].RightChild()
<< ", \"missing\": " << tree[nid].DefaultChild();
} else {
fo << nid << ":[" << fmap.Name(split_index)
<< "<" << std::setprecision(float_max_precision) << cond
<< "] yes=" << tree[nid].LeftChild()
<< ",no=" << tree[nid].RightChild()
<< ",missing=" << tree[nid].DefaultChild();
}
break; break;
} }
default: LOG(FATAL) << "unknown fmap type"; default:
} LOG(FATAL) << "Unknown feature map type.";
}
} else { } else {
if (format == "json") { result = this->PlainNode(tree, nid, depth);
fo << "{ \"nodeid\": " << nid }
<< ", \"depth\": " << depth return result;
<< ", \"split\": " << split_index }
<< ", \"split_condition\": " << std::setprecision(float_max_precision) << cond
<< ", \"yes\": " << tree[nid].LeftChild() virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
<< ", \"no\": " << tree[nid].RightChild() virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
<< ", \"missing\": " << tree[nid].DefaultChild();
} else { public:
fo << nid << ":[f" << split_index << "<"<< std::setprecision(float_max_precision) << cond TreeGenerator(FeatureMap const& _fmap, bool with_stats) :
<< "] yes=" << tree[nid].LeftChild() fmap_{_fmap}, with_stats_{with_stats} {}
<< ",no=" << tree[nid].RightChild() virtual ~TreeGenerator() = default;
<< ",missing=" << tree[nid].DefaultChild();
virtual void BuildTree(RegTree const& tree) {
ss_ << this->BuildTree(tree, 0, 0);
}
std::string Str() const {
return ss_.str();
}
static TreeGenerator* Create(std::string const& attrs, FeatureMap const& fmap,
bool with_stats);
};
struct TreeGenReg : public dmlc::FunctionRegEntryBase<
TreeGenReg,
std::function<TreeGenerator* (
FeatureMap const& fmap, std::string attrs, bool with_stats)> > {
};
} // namespace xgboost
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::TreeGenReg);
} // namespace dmlc
namespace xgboost {
TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const& fmap,
bool with_stats) {
auto pos = attrs.find(':');
std::string name;
std::string params;
if (pos != std::string::npos) {
name = attrs.substr(0, pos);
params = attrs.substr(pos+1, attrs.length() - pos - 1);
// Eliminate all occurances of single quote string.
size_t pos = std::string::npos;
while ((pos = params.find('\'')) != std::string::npos) {
params.replace(pos, 1, "\"");
}
} else {
name = attrs;
}
auto *e = ::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown Model Builder:" << name;
}
auto p_io_builder = (e->body)(fmap, params, with_stats);
return p_io_builder;
}
#define XGBOOST_REGISTER_TREE_IO(UniqueId, Name) \
static DMLC_ATTRIBUTE_UNUSED ::xgboost::TreeGenReg& \
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)
class TextGenerator : public TreeGenerator {
using SuperT = TreeGenerator;
public:
TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {}
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}";
static std::string kStatTemplate = ",cover={cover}";
std::string result = SuperT::Match(
kLeafTemplate,
{{"{tabs}", SuperT::Tabs(depth)},
{"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
{"{stats}", with_stats_ ?
SuperT::Match(kStatTemplate,
{{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result;
}
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}";
int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
auto split_index = tree[nid].SplitIndex();
std::string result = SuperT::Match(
kIndicatorTemplate,
{{"{nid}", std::to_string(nid)},
{"{fname}", fmap_.Name(split_index)},
{"{yes}", std::to_string(nyes)},
{"{no}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
std::string SplitNodeImpl(
RegTree const& tree, int32_t nid, std::string const& template_str,
std::string cond, uint32_t depth) {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
{{"{tabs}", SuperT::Tabs(depth)},
{"{nid}", std::to_string(nid)},
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
std::to_string(split_index)},
{"{cond}", cond},
{"{left}", std::to_string(tree[nid].LeftChild())},
{"{right}", std::to_string(tree[nid].RightChild())},
{"{missing}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kIntegerTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond();
const bst_float floored = std::floor(cond);
const int32_t integer_threshold
= (floored == cond) ? static_cast<int>(floored)
: static_cast<int>(floored) + 1;
return SplitNodeImpl(tree, nid, kIntegerTemplate,
std::to_string(integer_threshold), depth);
}
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kQuantitiveTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
}
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
"{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
}
std::string NodeStat(RegTree const& tree, int32_t nid) override {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match(
kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
return result;
}
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
if (tree[nid].IsLeaf()) {
return this->LeafNode(tree, nid, depth);
}
static std::string const kNodeTemplate = "{parent}{stat}\n{left}\n{right}";
auto result = SuperT::Match(
kNodeTemplate,
{{"{parent}", this->SplitNode(tree, nid, depth)},
{"{stat}", with_stats_ ? this->NodeStat(tree, nid) : ""},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
return result;
}
void BuildTree(RegTree const& tree) override {
static std::string const& kTreeTemplate = "{nodes}\n";
auto result = SuperT::Match(
kTreeTemplate,
{{"{nodes}", this->BuildTree(tree, 0, 0)}});
ss_ << result;
}
};
XGBOOST_REGISTER_TREE_IO(TextGenerator, "text")
.describe("Dump text representation of tree")
.set_body([](FeatureMap const& fmap, std::string const& attrs, bool with_stats) {
return new TextGenerator(fmap, attrs, with_stats);
});
class JsonGenerator : public TreeGenerator {
using SuperT = TreeGenerator;
public:
JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {}
std::string Indent(uint32_t depth) {
std::string result;
for (uint32_t i = 0; i < depth + 1; ++i) {
result += " ";
}
return result;
}
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kLeafTemplate =
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
static std::string const kStatTemplate =
R"S(, "cover": {sum_hess} )S";
std::string result = SuperT::Match(
kLeafTemplate,
{{"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
{"{stat}", with_stats_ ? SuperT::Match(
kStatTemplate,
{{"{sum_hess}",
SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result;
}
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
static std::string const kIndicatorTemplate =
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no}})ID";
auto split_index = tree[nid].SplitIndex();
auto result = SuperT::Match(
kIndicatorTemplate,
{{"{nid}", std::to_string(nid)},
{"{depth}", std::to_string(depth)},
{"{fname}", fmap_.Name(split_index)},
{"{yes}", std::to_string(nyes)},
{"{no}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
std::string SplitNodeImpl(RegTree const& tree, int32_t nid,
std::string const& template_str, std::string cond, uint32_t depth) {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
{{"{nid}", std::to_string(nid)},
{"{depth}", std::to_string(depth)},
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
std::to_string(split_index)},
{"{cond}", cond},
{"{left}", std::to_string(tree[nid].LeftChild())},
{"{right}", std::to_string(tree[nid].RightChild())},
{"{missing}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
auto cond = tree[nid].SplitCond();
const bst_float floored = std::floor(cond);
const int32_t integer_threshold
= (floored == cond) ? static_cast<int32_t>(floored)
: static_cast<int32_t>(floored) + 1;
static std::string const kIntegerTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
return SplitNodeImpl(tree, nid, kIntegerTemplate,
std::to_string(integer_threshold), depth);
}
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kQuantitiveTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
bst_float cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
}
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
}
std::string NodeStat(RegTree const& tree, int32_t nid) override {
static std::string kStatTemplate =
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
auto result = SuperT::Match(
kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
return result;
}
std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string properties = SuperT::SplitNode(tree, nid, depth);
static std::string const kSplitNodeTemplate =
"{{properties} {stat}, \"children\": [{left}, {right}\n{indent}]}";
auto result = SuperT::Match(
kSplitNodeTemplate,
{{"{properties}", properties},
{"{stat}", with_stats_ ? this->NodeStat(tree, nid) : ""},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)},
{"{indent}", this->Indent(depth)}});
return result;
}
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kNodeTemplate = "{newline}{indent}{nodes}";
auto result = SuperT::Match(
kNodeTemplate,
{{"{newline}", depth == 0 ? "" : "\n"},
{"{indent}", Indent(depth)},
{"{nodes}", tree[nid].IsLeaf() ? this->LeafNode(tree, nid, depth) :
this->SplitNode(tree, nid, depth)}});
return result;
}
};
XGBOOST_REGISTER_TREE_IO(JsonGenerator, "json")
.describe("Dump json representation of tree")
.set_body([](FeatureMap const& fmap, std::string const& attrs, bool with_stats) {
return new JsonGenerator(fmap, attrs, with_stats);
});
struct GraphvizParam : public dmlc::Parameter<GraphvizParam> {
std::string yes_color;
std::string no_color;
std::string rankdir;
std::string condition_node_params;
std::string leaf_node_params;
std::string graph_attrs;
DMLC_DECLARE_PARAMETER(GraphvizParam){
DMLC_DECLARE_FIELD(yes_color)
.set_default("#0000FF")
.describe("Edge color when meets the node condition.");
DMLC_DECLARE_FIELD(no_color)
.set_default("#FF0000")
.describe("Edge color when doesn't meet the node condition.");
DMLC_DECLARE_FIELD(rankdir)
.set_default("TB")
.describe("Passed to graphiz via graph_attr.");
DMLC_DECLARE_FIELD(condition_node_params)
.set_default("")
.describe("Conditional node configuration");
DMLC_DECLARE_FIELD(leaf_node_params)
.set_default("")
.describe("Leaf node configuration");
DMLC_DECLARE_FIELD(graph_attrs)
.set_default("")
.describe("Any other extra attributes for graphviz `graph_attr`.");
}
};
DMLC_REGISTER_PARAMETER(GraphvizParam);
class GraphvizGenerator : public TreeGenerator {
using SuperT = TreeGenerator;
std::stringstream& ss_;
GraphvizParam param_;
public:
GraphvizGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
TreeGenerator(fmap, with_stats), ss_{SuperT::ss_} {
param_.InitAllowUnknown(std::map<std::string, std::string>{});
using KwArg = std::map<std::string, std::map<std::string, std::string>>;
KwArg kwargs;
if (attrs.length() != 0) {
std::istringstream iss(attrs);
try {
dmlc::JSONReader reader(&iss);
reader.Read(&kwargs);
} catch(dmlc::Error const& e) {
LOG(FATAL) << "Failed to parse graphviz parameters:\n\t"
<< attrs << "\n"
<< "With error:\n"
<< e.what();
} }
} }
if (with_stats) { // This turns out to be tricky, as `dmlc::Parameter::Load(JSONReader*)` doesn't
if (format == "json") { // support loading nested json objects.
fo << ", \"gain\": " << std::setprecision(float_max_precision) << tree.Stat(nid).loss_chg if (kwargs.find("condition_node_params") != kwargs.cend()) {
<< ", \"cover\": " << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess; auto const& cnp = kwargs["condition_node_params"];
} else { for (auto const& kv : cnp) {
fo << ",gain=" << std::setprecision(float_max_precision) << tree.Stat(nid).loss_chg param_.condition_node_params += kv.first + '=' + "\"" + kv.second + "\" ";
<< ",cover=" << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
} }
kwargs.erase("condition_node_params");
} }
if (format == "json") { if (kwargs.find("leaf_node_params") != kwargs.cend()) {
fo << ", \"children\": ["; auto const& lnp = kwargs["leaf_node_params"];
} else { for (auto const& kv : lnp) {
fo << '\n'; param_.leaf_node_params += kv.first + '=' + "\"" + kv.second + "\" ";
}
DumpRegTree(fo, tree, fmap, tree[nid].LeftChild(), depth + 1, false, with_stats, format);
DumpRegTree(fo, tree, fmap, tree[nid].RightChild(), depth + 1, true, with_stats, format);
if (format == "json") {
fo << std::endl;
for (int i = 0; i < depth + 1; ++i) {
fo << " ";
} }
fo << "]}"; kwargs.erase("leaf_node_params");
}
if (kwargs.find("edge") != kwargs.cend()) {
if (kwargs["edge"].find("yes_color") != kwargs["edge"].cend()) {
param_.yes_color = kwargs["edge"]["yes_color"];
}
if (kwargs["edge"].find("no_color") != kwargs["edge"].cend()) {
param_.no_color = kwargs["edge"]["no_color"];
}
kwargs.erase("edge");
}
auto const& extra = kwargs["graph_attrs"];
static std::string const kGraphTemplate = " graph [ {key}=\"{value}\" ]\n";
for (auto const& kv : extra) {
param_.graph_attrs += SuperT::Match(kGraphTemplate,
{{"{key}", kv.first},
{"{value}", kv.second}});
}
kwargs.erase("graph_attrs");
if (kwargs.size() != 0) {
std::stringstream ss;
ss << "The following parameters for graphviz are not recognized:\n";
for (auto kv : kwargs) {
ss << kv.first << ", ";
}
LOG(WARNING) << ss.str();
} }
} }
}
protected:
// Only indicator is different, so we combine all different node types into this
// function.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
auto split = tree[nid].SplitIndex();
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
" {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
// Indicator only has fname.
bool has_less = (split >= fmap_.Size()) || fmap_.type(split) != FeatureMap::kIndicator;
std::string result = SuperT::Match(kNodeTemplate, {
{"{nid}", std::to_string(nid)},
{"{fname}", split < fmap_.Size() ? fmap_.Name(split) :
'f' + std::to_string(split)},
{"{<}", has_less ? "<" : ""},
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
{"{params}", param_.condition_node_params}});
static std::string const kEdgeTemplate =
" {nid} -> {child} [label=\"{is_missing}\" color=\"{color}\"]\n";
auto MatchFn = SuperT::Match; // mingw failed to capture protected fn.
auto BuildEdge =
[&tree, nid, MatchFn, this](int32_t child) {
bool is_missing = tree[nid].DefaultChild() == child;
std::string buffer = MatchFn(kEdgeTemplate, {
{"{nid}", std::to_string(nid)},
{"{child}", std::to_string(child)},
{"{color}", is_missing ? param_.yes_color : param_.no_color},
{"{is_missing}", is_missing ? "yes, missing": "no"}});
return buffer;
};
result += BuildEdge(tree[nid].LeftChild());
result += BuildEdge(tree[nid].RightChild());
return result;
};
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kLeafTemplate =
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
auto result = SuperT::Match(kLeafTemplate, {
{"{nid}", std::to_string(nid)},
{"{leaf-value}", ToStr(tree[nid].LeafValue())},
{"{params}", param_.leaf_node_params}});
return result;
};
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
if (tree[nid].IsLeaf()) {
return this->LeafNode(tree, nid, depth);
}
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
auto result = SuperT::Match(
kNodeTemplate,
{{"{parent}", this->PlainNode(tree, nid, depth)},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
return result;
}
void BuildTree(RegTree const& tree) override {
static std::string const kTreeTemplate =
"digraph {\n"
" graph [ rankdir={rankdir} ]\n"
"{graph_attrs}\n"
"{nodes}}";
auto result = SuperT::Match(
kTreeTemplate,
{{"{rankdir}", param_.rankdir},
{"{graph_attrs}", param_.graph_attrs},
{"{nodes}", this->BuildTree(tree, 0, 0)}});
ss_ << result;
};
};
// Register the graphviz dump format under the name "dot".
// Take `attrs` by const reference for consistency with the json registration
// (and to avoid an unnecessary string copy).
XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
.describe("Dump graphviz representation of tree")
.set_body([](FeatureMap const& fmap, std::string const& attrs, bool with_stats) {
    return new GraphvizGenerator(fmap, attrs, with_stats);
  });
std::string RegTree::DumpModel(const FeatureMap& fmap, std::string RegTree::DumpModel(const FeatureMap& fmap,
bool with_stats, bool with_stats,
std::string format) const { std::string format) const {
std::stringstream fo(""); std::unique_ptr<TreeGenerator> builder {
for (int i = 0; i < param.num_roots; ++i) { TreeGenerator::Create(format, fmap, with_stats)
DumpRegTree(fo, *this, fmap, i, 0, false, with_stats, format); };
for (int32_t i = 0; i < param.num_roots; ++i) {
builder->BuildTree(*this);
} }
return fo.str();
std::string result = builder->Str();
return result;
} }
void RegTree::FillNodeMeanValues() { void RegTree::FillNodeMeanValues() {
size_t num_nodes = this->param.num_nodes; size_t num_nodes = this->param.num_nodes;
if (this->node_mean_values_.size() == num_nodes) { if (this->node_mean_values_.size() == num_nodes) {

View File

@ -144,7 +144,7 @@ __global__ void CubScanByKeyL1(
int previousKey = __shfl_up_sync(0xFFFFFFFF, myKey, 1); int previousKey = __shfl_up_sync(0xFFFFFFFF, myKey, 1);
#else #else
int previousKey = __shfl_up(myKey, 1); int previousKey = __shfl_up(myKey, 1);
#endif #endif // (__CUDACC_VER_MAJOR__ >= 9)
// Collectively compute the block-wide exclusive prefix sum // Collectively compute the block-wide exclusive prefix sum
BlockScan(temp_storage) BlockScan(temp_storage)
.ExclusiveScan(threadData, threadData, rootPair, AddByKey()); .ExclusiveScan(threadData, threadData, rootPair, AddByKey());

View File

@ -101,4 +101,121 @@ TEST(Tree, AllocateNode) {
ASSERT_TRUE(nodes.at(1).IsLeaf()); ASSERT_TRUE(nodes.at(1).IsLeaf());
ASSERT_TRUE(nodes.at(2).IsLeaf()); ASSERT_TRUE(nodes.at(2).IsLeaf());
} }
// Build a depth-2 test tree: root splits on f0, its two children split on
// f1 and f2 respectively, yielding four leaves.
RegTree ConstructTree() {
  RegTree tree;
  tree.ExpandNode(/*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
                  /*default_left=*/true,
                  0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
  int32_t const children[] = {tree[0].LeftChild(), tree[0].RightChild()};
  for (int32_t i = 0; i < 2; ++i) {
    // Child i splits on feature i+1 with threshold i+1, default right.
    tree.ExpandNode(/*nid=*/children[i], /*split_index=*/i + 1,
                    /*split_value=*/static_cast<float>(i + 1),
                    /*default_left=*/false,
                    0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
  }
  return tree;
}
TEST(Tree, DumpJson) {
  auto tree = ConstructTree();
  FeatureMap fmap;
  auto str = tree.DumpModel(fmap, true, "json");

  // Count non-overlapping occurrences of `token`, scanning from index 1
  // (matches the original search behaviour).
  auto CountToken = [](std::string const& text, std::string const& token) {
    size_t total = 0;
    size_t pos = 0;
    while ((pos = text.find(token, pos + 1)) != std::string::npos) {
      ++total;
    }
    return total;
  };

  ASSERT_EQ(CountToken(str, "leaf"), 4);
  ASSERT_EQ(CountToken(str, "split_condition"), 3);

  // With a feature map the dump must use the mapped names.
  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");
  str = tree.DumpModel(fmap, true, "json");
  ASSERT_NE(str.find(R"("split": "feat_0")"), std::string::npos);
  ASSERT_NE(str.find(R"("split": "feat_1")"), std::string::npos);
  ASSERT_NE(str.find(R"("split": "feat_2")"), std::string::npos);

  // Statistics must be omitted when with_stats is false.
  str = tree.DumpModel(fmap, false, "json");
  ASSERT_EQ(str.find("cover"), std::string::npos);
}
TEST(Tree, DumpText) {
  auto tree = ConstructTree();
  FeatureMap fmap;
  auto str = tree.DumpModel(fmap, true, "text");

  // Count non-overlapping occurrences of `token`, scanning from index 1
  // (matches the original search behaviour).
  auto CountToken = [](std::string const& text, std::string const& token) {
    size_t total = 0;
    size_t pos = 0;
    while ((pos = text.find(token, pos + 1)) != std::string::npos) {
      ++total;
    }
    return total;
  };

  ASSERT_EQ(CountToken(str, "leaf"), 4);
  ASSERT_EQ(CountToken(str, "gain"), 3);

  // Unmapped features are rendered as f<index><threshold>.
  ASSERT_NE(str.find("[f0<0]"), std::string::npos);
  ASSERT_NE(str.find("[f1<1]"), std::string::npos);
  ASSERT_NE(str.find("[f2<2]"), std::string::npos);

  // Mapped names replace the f<index> form; the indicator feature ("i")
  // renders without a threshold.
  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");
  str = tree.DumpModel(fmap, true, "text");
  ASSERT_NE(str.find("[feat_0]"), std::string::npos);
  ASSERT_NE(str.find("[feat_1<1]"), std::string::npos);
  ASSERT_NE(str.find("[feat_2<2]"), std::string::npos);

  // Statistics must be omitted when with_stats is false.
  str = tree.DumpModel(fmap, false, "text");
  ASSERT_EQ(str.find("cover"), std::string::npos);
}
TEST(Tree, DumpDot) {
  auto tree = ConstructTree();
  FeatureMap fmap;
  auto str = tree.DumpModel(fmap, true, "dot");

  // Count non-overlapping occurrences of `token`, scanning from index 1
  // (matches the original search behaviour).
  auto CountToken = [](std::string const& text, std::string const& token) {
    size_t total = 0;
    size_t pos = 0;
    while ((pos = text.find(token, pos + 1)) != std::string::npos) {
      ++total;
    }
    return total;
  };

  ASSERT_EQ(CountToken(str, "leaf"), 4);
  // Each of the three split nodes has two outgoing edges.
  ASSERT_EQ(CountToken(str, "->"), 6);

  // Mapped feature names must appear in node labels.
  fmap.PushBack(0, "feat_0", "i");
  fmap.PushBack(1, "feat_1", "q");
  fmap.PushBack(2, "feat_2", "int");
  str = tree.DumpModel(fmap, true, "dot");
  ASSERT_NE(str.find(R"("feat_0")"), std::string::npos);
  ASSERT_NE(str.find(R"(feat_1<1)"), std::string::npos);
  ASSERT_NE(str.find(R"(feat_2<2)"), std::string::npos);

  // Extra graph attributes supplied via JSON kwargs are forwarded to graphviz.
  str = tree.DumpModel(fmap, true, R"(dot:{"graph_attrs": {"bgcolor": "#FFFF00"}})");
  ASSERT_NE(str.find(R"(graph [ bgcolor="#FFFF00" ])"), std::string::npos);
}
} // namespace xgboost } // namespace xgboost

View File

@ -10,7 +10,7 @@ try:
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
from matplotlib.axes import Axes from matplotlib.axes import Axes
from graphviz import Digraph from graphviz import Source
except ImportError: except ImportError:
pass pass
@ -57,7 +57,7 @@ class TestPlotting(unittest.TestCase):
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
g = xgb.to_graphviz(bst2, num_trees=0) g = xgb.to_graphviz(bst2, num_trees=0)
assert isinstance(g, Digraph) assert isinstance(g, Source)
ax = xgb.plot_tree(bst2, num_trees=0) ax = xgb.plot_tree(bst2, num_trees=0)
assert isinstance(ax, Axes) assert isinstance(ax, Axes)

View File

@ -87,7 +87,6 @@ class TestSHAP(unittest.TestCase):
r_exp = r"([0-9]+):\[f([0-9]+)<([0-9\.e-]+)\] yes=([0-9]+),no=([0-9]+).*cover=([0-9e\.]+)" r_exp = r"([0-9]+):\[f([0-9]+)<([0-9\.e-]+)\] yes=([0-9]+),no=([0-9]+).*cover=([0-9e\.]+)"
r_exp_leaf = r"([0-9]+):leaf=([0-9\.e-]+),cover=([0-9e\.]+)" r_exp_leaf = r"([0-9]+):leaf=([0-9\.e-]+),cover=([0-9e\.]+)"
for tree in model.get_dump(with_stats=True): for tree in model.get_dump(with_stats=True):
lines = list(tree.splitlines()) lines = list(tree.splitlines())
trees.append([None for i in range(len(lines))]) trees.append([None for i in range(len(lines))])
for line in lines: for line in lines:

View File

@ -352,7 +352,7 @@ def test_sklearn_plotting():
matplotlib.use('Agg') matplotlib.use('Agg')
from matplotlib.axes import Axes from matplotlib.axes import Axes
from graphviz import Digraph from graphviz import Source
ax = xgb.plot_importance(classifier) ax = xgb.plot_importance(classifier)
assert isinstance(ax, Axes) assert isinstance(ax, Axes)
@ -362,7 +362,7 @@ def test_sklearn_plotting():
assert len(ax.patches) == 4 assert len(ax.patches) == 4
g = xgb.to_graphviz(classifier, num_trees=0) g = xgb.to_graphviz(classifier, num_trees=0)
assert isinstance(g, Digraph) assert isinstance(g, Source)
ax = xgb.plot_tree(classifier, num_trees=0) ax = xgb.plot_tree(classifier, num_trees=0)
assert isinstance(ax, Axes) assert isinstance(ax, Axes)
@ -641,7 +641,8 @@ def test_XGBClassifier_resume():
X, Y = load_breast_cancer(return_X_y=True) X, Y = load_breast_cancer(return_X_y=True)
model1 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8) model1 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8)
model1.fit(X, Y) model1.fit(X, Y)
pred1 = model1.predict(X) pred1 = model1.predict(X)
@ -649,7 +650,8 @@ def test_XGBClassifier_resume():
# file name of stored xgb model # file name of stored xgb model
model1.save_model(model1_path) model1.save_model(model1_path)
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8) model2 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8)
model2.fit(X, Y, xgb_model=model1_path) model2.fit(X, Y, xgb_model=model1_path)
pred2 = model2.predict(X) pred2 = model2.predict(X)
@ -660,7 +662,8 @@ def test_XGBClassifier_resume():
# file name of 'Booster' instance Xgb model # file name of 'Booster' instance Xgb model
model1.get_booster().save_model(model1_booster_path) model1.get_booster().save_model(model1_booster_path)
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8) model2 = xgb.XGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=8)
model2.fit(X, Y, xgb_model=model1_booster_path) model2.fit(X, Y, xgb_model=model1_booster_path)
pred2 = model2.predict(X) pred2 = model2.predict(X)