Implement tree model dump with code generator. (#4602)

* Implement tree model dump with a code generator.

* Split up generators.
* Implement graphviz generator.
* Use pattern matching.

* [Breaking] Return a Source in `to_graphviz` instead of Digraph in Python package.


Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2019-06-26 15:20:44 +08:00
committed by GitHub
parent fe2de6f415
commit 8bdf15120a
11 changed files with 802 additions and 264 deletions

View File

@@ -1033,19 +1033,21 @@ inline void XGBoostDumpModelImpl(
*out_models = dmlc::BeginPtr(charp_vecs);
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
}
XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
const char* fmap,
int with_stats,
xgboost::bst_ulong* len,
const char*** out_models) {
const char* fmap,
int with_stats,
xgboost::bst_ulong* len,
const char*** out_models) {
return XGBoosterDumpModelEx(handle, fmap, with_stats, "text", len, out_models);
}
XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
const char* fmap,
int with_stats,
const char *format,
xgboost::bst_ulong* len,
const char*** out_models) {
const char* fmap,
int with_stats,
const char *format,
xgboost::bst_ulong* len,
const char*** out_models) {
API_BEGIN();
CHECK_HANDLE();
FeatureMap featmap;
@@ -1060,23 +1062,24 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
}
XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
int fnum,
const char** fname,
const char** ftype,
int with_stats,
xgboost::bst_ulong* len,
const char*** out_models) {
int fnum,
const char** fname,
const char** ftype,
int with_stats,
xgboost::bst_ulong* len,
const char*** out_models) {
return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats,
"text", len, out_models);
"text", len, out_models);
}
XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
int fnum,
const char** fname,
const char** ftype,
int with_stats,
const char *format,
xgboost::bst_ulong* len,
const char*** out_models) {
int fnum,
const char** fname,
const char** ftype,
int with_stats,
const char *format,
xgboost::bst_ulong* len,
const char*** out_models) {
API_BEGIN();
CHECK_HANDLE();
FeatureMap featmap;

View File

@@ -10,7 +10,7 @@
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
#include <nvToolsExt.h>
#endif
#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
namespace xgboost {
namespace common {
@@ -98,7 +98,7 @@ struct Monitor {
stats.timer.Start();
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
stats.nvtx_id = nvtxRangeStartA(name.c_str());
#endif
#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
}
}
void StopCuda(const std::string &name) {
@@ -108,7 +108,7 @@ struct Monitor {
stats.count++;
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
nvtxRangeEnd(stats.nvtx_id);
#endif
#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
}
}
};

View File

@@ -1,14 +1,19 @@
/*!
* Copyright 2015 by Contributors
* Copyright 2015-2019 by Contributors
* \file tree_model.cc
* \brief model structure for tree
*/
#include <dmlc/registry.h>
#include <dmlc/json.h>
#include <xgboost/tree_model.h>
#include <xgboost/logging.h>
#include <sstream>
#include <limits>
#include <cmath>
#include <iomanip>
#include "./param.h"
#include "param.h"
namespace xgboost {
// register tree parameter
@@ -17,158 +22,602 @@ DMLC_REGISTER_PARAMETER(TreeParam);
namespace tree {
DMLC_REGISTER_PARAMETER(TrainParam);
}
// internal function to dump regression tree to text
void DumpRegTree(std::stringstream& fo, // NOLINT(*)
const RegTree& tree,
const FeatureMap& fmap,
int nid, int depth, int add_comma,
bool with_stats, std::string format) {
int float_max_precision = std::numeric_limits<bst_float>::max_digits10;
if (format == "json") {
if (add_comma) {
fo << ",";
}
if (depth != 0) {
fo << std::endl;
}
for (int i = 0; i < depth + 1; ++i) {
fo << " ";
}
} else {
for (int i = 0; i < depth; ++i) {
fo << '\t';
}
/*!
* \brief Base class for dump model implementation, modeling closely after code generator.
*/
class TreeGenerator {
protected:
static int32_t constexpr kFloatMaxPrecision =
std::numeric_limits<bst_float>::max_digits10;
FeatureMap const& fmap_;
std::stringstream ss_;
bool const with_stats_;
template <typename Float>
static std::string ToStr(Float value) {
static_assert(std::is_floating_point<Float>::value,
"Use std::to_string instead for non-floating point values.");
std::stringstream ss;
ss << std::setprecision(kFloatMaxPrecision) << value;
return ss.str();
}
if (tree[nid].IsLeaf()) {
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"leaf\": " << std::setprecision(float_max_precision) << tree[nid].LeafValue();
if (with_stats) {
fo << ", \"cover\": " << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
}
fo << " }";
} else {
fo << nid << ":leaf=" << std::setprecision(float_max_precision) << tree[nid].LeafValue();
if (with_stats) {
fo << ",cover=" << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
}
fo << '\n';
static std::string Tabs(uint32_t n) {
std::string res;
for (uint32_t i = 0; i < n; ++i) {
res += '\t';
}
} else {
// right then left,
bst_float cond = tree[nid].SplitCond();
const unsigned split_index = tree[nid].SplitIndex();
if (split_index < fmap.Size()) {
switch (fmap.type(split_index)) {
return res;
}
/* \brief Find the first occurance of key in input and replace it with corresponding
* value.
*/
static std::string Match(std::string const& input,
std::map<std::string, std::string> const& replacements) {
std::string result = input;
for (auto const& kv : replacements) {
auto pos = result.find(kv.first);
CHECK_NE(pos, std::string::npos);
result.replace(pos, kv.first.length(), kv.second);
}
return result;
}
virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) {
return "";
}
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) {
return "";
}
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) {
return "";
}
virtual std::string NodeStat(RegTree const& tree, int32_t nid) {
return "";
}
virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
auto const split_index = tree[nid].SplitIndex();
std::string result;
if (split_index < fmap_.Size()) {
switch (fmap_.type(split_index)) {
case FeatureMap::kIndicator: {
int nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
<< ", \"yes\": " << nyes
<< ", \"no\": " << tree[nid].DefaultChild();
} else {
fo << nid << ":[" << fmap.Name(split_index) << "] yes=" << nyes
<< ",no=" << tree[nid].DefaultChild();
}
result = this->Indicator(tree, nid, depth);
break;
}
case FeatureMap::kInteger: {
const bst_float floored = std::floor(cond);
const int integer_threshold
= (floored == cond) ? static_cast<int>(floored)
: static_cast<int>(floored) + 1;
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
<< ", \"split_condition\": " << integer_threshold
<< ", \"yes\": " << tree[nid].LeftChild()
<< ", \"no\": " << tree[nid].RightChild()
<< ", \"missing\": " << tree[nid].DefaultChild();
} else {
fo << nid << ":[" << fmap.Name(split_index) << "<"
<< integer_threshold
<< "] yes=" << tree[nid].LeftChild()
<< ",no=" << tree[nid].RightChild()
<< ",missing=" << tree[nid].DefaultChild();
}
result = this->Integer(tree, nid, depth);
break;
}
case FeatureMap::kFloat:
case FeatureMap::kQuantitive: {
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
<< ", \"split_condition\": " << std::setprecision(float_max_precision) << cond
<< ", \"yes\": " << tree[nid].LeftChild()
<< ", \"no\": " << tree[nid].RightChild()
<< ", \"missing\": " << tree[nid].DefaultChild();
} else {
fo << nid << ":[" << fmap.Name(split_index)
<< "<" << std::setprecision(float_max_precision) << cond
<< "] yes=" << tree[nid].LeftChild()
<< ",no=" << tree[nid].RightChild()
<< ",missing=" << tree[nid].DefaultChild();
}
result = this->Quantitive(tree, nid, depth);
break;
}
default: LOG(FATAL) << "unknown fmap type";
}
default:
LOG(FATAL) << "Unknown feature map type.";
}
} else {
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": " << split_index
<< ", \"split_condition\": " << std::setprecision(float_max_precision) << cond
<< ", \"yes\": " << tree[nid].LeftChild()
<< ", \"no\": " << tree[nid].RightChild()
<< ", \"missing\": " << tree[nid].DefaultChild();
} else {
fo << nid << ":[f" << split_index << "<"<< std::setprecision(float_max_precision) << cond
<< "] yes=" << tree[nid].LeftChild()
<< ",no=" << tree[nid].RightChild()
<< ",missing=" << tree[nid].DefaultChild();
result = this->PlainNode(tree, nid, depth);
}
return result;
}
virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
public:
TreeGenerator(FeatureMap const& _fmap, bool with_stats) :
fmap_{_fmap}, with_stats_{with_stats} {}
virtual ~TreeGenerator() = default;
virtual void BuildTree(RegTree const& tree) {
ss_ << this->BuildTree(tree, 0, 0);
}
std::string Str() const {
return ss_.str();
}
static TreeGenerator* Create(std::string const& attrs, FeatureMap const& fmap,
bool with_stats);
};
struct TreeGenReg : public dmlc::FunctionRegEntryBase<
TreeGenReg,
std::function<TreeGenerator* (
FeatureMap const& fmap, std::string attrs, bool with_stats)> > {
};
} // namespace xgboost
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::TreeGenReg);
} // namespace dmlc
namespace xgboost {
TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const& fmap,
bool with_stats) {
auto pos = attrs.find(':');
std::string name;
std::string params;
if (pos != std::string::npos) {
name = attrs.substr(0, pos);
params = attrs.substr(pos+1, attrs.length() - pos - 1);
// Eliminate all occurances of single quote string.
size_t pos = std::string::npos;
while ((pos = params.find('\'')) != std::string::npos) {
params.replace(pos, 1, "\"");
}
} else {
name = attrs;
}
auto *e = ::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown Model Builder:" << name;
}
auto p_io_builder = (e->body)(fmap, params, with_stats);
return p_io_builder;
}
#define XGBOOST_REGISTER_TREE_IO(UniqueId, Name) \
static DMLC_ATTRIBUTE_UNUSED ::xgboost::TreeGenReg& \
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)
class TextGenerator : public TreeGenerator {
using SuperT = TreeGenerator;
public:
TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {}
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}";
static std::string kStatTemplate = ",cover={cover}";
std::string result = SuperT::Match(
kLeafTemplate,
{{"{tabs}", SuperT::Tabs(depth)},
{"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
{"{stats}", with_stats_ ?
SuperT::Match(kStatTemplate,
{{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result;
}
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}";
int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
auto split_index = tree[nid].SplitIndex();
std::string result = SuperT::Match(
kIndicatorTemplate,
{{"{nid}", std::to_string(nid)},
{"{fname}", fmap_.Name(split_index)},
{"{yes}", std::to_string(nyes)},
{"{no}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
std::string SplitNodeImpl(
RegTree const& tree, int32_t nid, std::string const& template_str,
std::string cond, uint32_t depth) {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
{{"{tabs}", SuperT::Tabs(depth)},
{"{nid}", std::to_string(nid)},
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
std::to_string(split_index)},
{"{cond}", cond},
{"{left}", std::to_string(tree[nid].LeftChild())},
{"{right}", std::to_string(tree[nid].RightChild())},
{"{missing}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kIntegerTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond();
const bst_float floored = std::floor(cond);
const int32_t integer_threshold
= (floored == cond) ? static_cast<int>(floored)
: static_cast<int>(floored) + 1;
return SplitNodeImpl(tree, nid, kIntegerTemplate,
std::to_string(integer_threshold), depth);
}
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kQuantitiveTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
}
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
"{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
}
std::string NodeStat(RegTree const& tree, int32_t nid) override {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match(
kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
return result;
}
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
if (tree[nid].IsLeaf()) {
return this->LeafNode(tree, nid, depth);
}
static std::string const kNodeTemplate = "{parent}{stat}\n{left}\n{right}";
auto result = SuperT::Match(
kNodeTemplate,
{{"{parent}", this->SplitNode(tree, nid, depth)},
{"{stat}", with_stats_ ? this->NodeStat(tree, nid) : ""},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
return result;
}
void BuildTree(RegTree const& tree) override {
static std::string const& kTreeTemplate = "{nodes}\n";
auto result = SuperT::Match(
kTreeTemplate,
{{"{nodes}", this->BuildTree(tree, 0, 0)}});
ss_ << result;
}
};
XGBOOST_REGISTER_TREE_IO(TextGenerator, "text")
.describe("Dump text representation of tree")
.set_body([](FeatureMap const& fmap, std::string const& attrs, bool with_stats) {
return new TextGenerator(fmap, attrs, with_stats);
});
class JsonGenerator : public TreeGenerator {
using SuperT = TreeGenerator;
public:
JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {}
std::string Indent(uint32_t depth) {
std::string result;
for (uint32_t i = 0; i < depth + 1; ++i) {
result += " ";
}
return result;
}
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kLeafTemplate =
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
static std::string const kStatTemplate =
R"S(, "cover": {sum_hess} )S";
std::string result = SuperT::Match(
kLeafTemplate,
{{"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
{"{stat}", with_stats_ ? SuperT::Match(
kStatTemplate,
{{"{sum_hess}",
SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result;
}
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
static std::string const kIndicatorTemplate =
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no}})ID";
auto split_index = tree[nid].SplitIndex();
auto result = SuperT::Match(
kIndicatorTemplate,
{{"{nid}", std::to_string(nid)},
{"{depth}", std::to_string(depth)},
{"{fname}", fmap_.Name(split_index)},
{"{yes}", std::to_string(nyes)},
{"{no}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
std::string SplitNodeImpl(RegTree const& tree, int32_t nid,
std::string const& template_str, std::string cond, uint32_t depth) {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
{{"{nid}", std::to_string(nid)},
{"{depth}", std::to_string(depth)},
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
std::to_string(split_index)},
{"{cond}", cond},
{"{left}", std::to_string(tree[nid].LeftChild())},
{"{right}", std::to_string(tree[nid].RightChild())},
{"{missing}", std::to_string(tree[nid].DefaultChild())}});
return result;
}
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
auto cond = tree[nid].SplitCond();
const bst_float floored = std::floor(cond);
const int32_t integer_threshold
= (floored == cond) ? static_cast<int32_t>(floored)
: static_cast<int32_t>(floored) + 1;
static std::string const kIntegerTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
return SplitNodeImpl(tree, nid, kIntegerTemplate,
std::to_string(integer_threshold), depth);
}
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kQuantitiveTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
bst_float cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
}
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
}
std::string NodeStat(RegTree const& tree, int32_t nid) override {
static std::string kStatTemplate =
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
auto result = SuperT::Match(
kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
return result;
}
std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string properties = SuperT::SplitNode(tree, nid, depth);
static std::string const kSplitNodeTemplate =
"{{properties} {stat}, \"children\": [{left}, {right}\n{indent}]}";
auto result = SuperT::Match(
kSplitNodeTemplate,
{{"{properties}", properties},
{"{stat}", with_stats_ ? this->NodeStat(tree, nid) : ""},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)},
{"{indent}", this->Indent(depth)}});
return result;
}
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kNodeTemplate = "{newline}{indent}{nodes}";
auto result = SuperT::Match(
kNodeTemplate,
{{"{newline}", depth == 0 ? "" : "\n"},
{"{indent}", Indent(depth)},
{"{nodes}", tree[nid].IsLeaf() ? this->LeafNode(tree, nid, depth) :
this->SplitNode(tree, nid, depth)}});
return result;
}
};
XGBOOST_REGISTER_TREE_IO(JsonGenerator, "json")
.describe("Dump json representation of tree")
.set_body([](FeatureMap const& fmap, std::string const& attrs, bool with_stats) {
return new JsonGenerator(fmap, attrs, with_stats);
});
struct GraphvizParam : public dmlc::Parameter<GraphvizParam> {
std::string yes_color;
std::string no_color;
std::string rankdir;
std::string condition_node_params;
std::string leaf_node_params;
std::string graph_attrs;
DMLC_DECLARE_PARAMETER(GraphvizParam){
DMLC_DECLARE_FIELD(yes_color)
.set_default("#0000FF")
.describe("Edge color when meets the node condition.");
DMLC_DECLARE_FIELD(no_color)
.set_default("#FF0000")
.describe("Edge color when doesn't meet the node condition.");
DMLC_DECLARE_FIELD(rankdir)
.set_default("TB")
.describe("Passed to graphiz via graph_attr.");
DMLC_DECLARE_FIELD(condition_node_params)
.set_default("")
.describe("Conditional node configuration");
DMLC_DECLARE_FIELD(leaf_node_params)
.set_default("")
.describe("Leaf node configuration");
DMLC_DECLARE_FIELD(graph_attrs)
.set_default("")
.describe("Any other extra attributes for graphviz `graph_attr`.");
}
};
DMLC_REGISTER_PARAMETER(GraphvizParam);
class GraphvizGenerator : public TreeGenerator {
using SuperT = TreeGenerator;
std::stringstream& ss_;
GraphvizParam param_;
public:
GraphvizGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
TreeGenerator(fmap, with_stats), ss_{SuperT::ss_} {
param_.InitAllowUnknown(std::map<std::string, std::string>{});
using KwArg = std::map<std::string, std::map<std::string, std::string>>;
KwArg kwargs;
if (attrs.length() != 0) {
std::istringstream iss(attrs);
try {
dmlc::JSONReader reader(&iss);
reader.Read(&kwargs);
} catch(dmlc::Error const& e) {
LOG(FATAL) << "Failed to parse graphviz parameters:\n\t"
<< attrs << "\n"
<< "With error:\n"
<< e.what();
}
}
if (with_stats) {
if (format == "json") {
fo << ", \"gain\": " << std::setprecision(float_max_precision) << tree.Stat(nid).loss_chg
<< ", \"cover\": " << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
} else {
fo << ",gain=" << std::setprecision(float_max_precision) << tree.Stat(nid).loss_chg
<< ",cover=" << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
// This turns out to be tricky, as `dmlc::Parameter::Load(JSONReader*)` doesn't
// support loading nested json objects.
if (kwargs.find("condition_node_params") != kwargs.cend()) {
auto const& cnp = kwargs["condition_node_params"];
for (auto const& kv : cnp) {
param_.condition_node_params += kv.first + '=' + "\"" + kv.second + "\" ";
}
kwargs.erase("condition_node_params");
}
if (format == "json") {
fo << ", \"children\": [";
} else {
fo << '\n';
}
DumpRegTree(fo, tree, fmap, tree[nid].LeftChild(), depth + 1, false, with_stats, format);
DumpRegTree(fo, tree, fmap, tree[nid].RightChild(), depth + 1, true, with_stats, format);
if (format == "json") {
fo << std::endl;
for (int i = 0; i < depth + 1; ++i) {
fo << " ";
if (kwargs.find("leaf_node_params") != kwargs.cend()) {
auto const& lnp = kwargs["leaf_node_params"];
for (auto const& kv : lnp) {
param_.leaf_node_params += kv.first + '=' + "\"" + kv.second + "\" ";
}
fo << "]}";
kwargs.erase("leaf_node_params");
}
if (kwargs.find("edge") != kwargs.cend()) {
if (kwargs["edge"].find("yes_color") != kwargs["edge"].cend()) {
param_.yes_color = kwargs["edge"]["yes_color"];
}
if (kwargs["edge"].find("no_color") != kwargs["edge"].cend()) {
param_.no_color = kwargs["edge"]["no_color"];
}
kwargs.erase("edge");
}
auto const& extra = kwargs["graph_attrs"];
static std::string const kGraphTemplate = " graph [ {key}=\"{value}\" ]\n";
for (auto const& kv : extra) {
param_.graph_attrs += SuperT::Match(kGraphTemplate,
{{"{key}", kv.first},
{"{value}", kv.second}});
}
kwargs.erase("graph_attrs");
if (kwargs.size() != 0) {
std::stringstream ss;
ss << "The following parameters for graphviz are not recognized:\n";
for (auto kv : kwargs) {
ss << kv.first << ", ";
}
LOG(WARNING) << ss.str();
}
}
}
protected:
// Only indicator is different, so we combine all different node types into this
// function.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
auto split = tree[nid].SplitIndex();
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
" {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
// Indicator only has fname.
bool has_less = (split >= fmap_.Size()) || fmap_.type(split) != FeatureMap::kIndicator;
std::string result = SuperT::Match(kNodeTemplate, {
{"{nid}", std::to_string(nid)},
{"{fname}", split < fmap_.Size() ? fmap_.Name(split) :
'f' + std::to_string(split)},
{"{<}", has_less ? "<" : ""},
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
{"{params}", param_.condition_node_params}});
static std::string const kEdgeTemplate =
" {nid} -> {child} [label=\"{is_missing}\" color=\"{color}\"]\n";
auto MatchFn = SuperT::Match; // mingw failed to capture protected fn.
auto BuildEdge =
[&tree, nid, MatchFn, this](int32_t child) {
bool is_missing = tree[nid].DefaultChild() == child;
std::string buffer = MatchFn(kEdgeTemplate, {
{"{nid}", std::to_string(nid)},
{"{child}", std::to_string(child)},
{"{color}", is_missing ? param_.yes_color : param_.no_color},
{"{is_missing}", is_missing ? "yes, missing": "no"}});
return buffer;
};
result += BuildEdge(tree[nid].LeftChild());
result += BuildEdge(tree[nid].RightChild());
return result;
};
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
static std::string const kLeafTemplate =
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
auto result = SuperT::Match(kLeafTemplate, {
{"{nid}", std::to_string(nid)},
{"{leaf-value}", ToStr(tree[nid].LeafValue())},
{"{params}", param_.leaf_node_params}});
return result;
};
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
if (tree[nid].IsLeaf()) {
return this->LeafNode(tree, nid, depth);
}
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
auto result = SuperT::Match(
kNodeTemplate,
{{"{parent}", this->PlainNode(tree, nid, depth)},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
return result;
}
void BuildTree(RegTree const& tree) override {
static std::string const kTreeTemplate =
"digraph {\n"
" graph [ rankdir={rankdir} ]\n"
"{graph_attrs}\n"
"{nodes}}";
auto result = SuperT::Match(
kTreeTemplate,
{{"{rankdir}", param_.rankdir},
{"{graph_attrs}", param_.graph_attrs},
{"{nodes}", this->BuildTree(tree, 0, 0)}});
ss_ << result;
};
};
XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
.describe("Dump graphviz representation of tree")
.set_body([](FeatureMap const& fmap, std::string attrs, bool with_stats) {
return new GraphvizGenerator(fmap, attrs, with_stats);
});
std::string RegTree::DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const {
std::stringstream fo("");
for (int i = 0; i < param.num_roots; ++i) {
DumpRegTree(fo, *this, fmap, i, 0, false, with_stats, format);
std::unique_ptr<TreeGenerator> builder {
TreeGenerator::Create(format, fmap, with_stats)
};
for (int32_t i = 0; i < param.num_roots; ++i) {
builder->BuildTree(*this);
}
return fo.str();
std::string result = builder->Str();
return result;
}
void RegTree::FillNodeMeanValues() {
size_t num_nodes = this->param.num_nodes;
if (this->node_mean_values_.size() == num_nodes) {

View File

@@ -144,7 +144,7 @@ __global__ void CubScanByKeyL1(
int previousKey = __shfl_up_sync(0xFFFFFFFF, myKey, 1);
#else
int previousKey = __shfl_up(myKey, 1);
#endif
#endif // (__CUDACC_VER_MAJOR__ >= 9)
// Collectively compute the block-wide exclusive prefix sum
BlockScan(temp_storage)
.ExclusiveScan(threadData, threadData, rootPair, AddByKey());