Support graphviz plot for multi-target tree. (#10093)

This commit is contained in:
Jiaming Yuan
2024-03-09 05:35:25 +08:00
committed by GitHub
parent e14c3b9325
commit 2c13f90384
3 changed files with 133 additions and 63 deletions

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2015-2023, XGBoost Contributors
* Copyright 2015-2024, XGBoost Contributors
* \file tree_model.cc
* \brief model structure for tree
*/
@@ -8,6 +8,7 @@
#include <xgboost/json.h>
#include <xgboost/tree_model.h>
#include <array> // for array
#include <cmath>
#include <iomanip>
#include <limits>
@@ -15,7 +16,7 @@
#include <type_traits>
#include "../common/categorical.h"
#include "../common/common.h" // for EscapeU8
#include "../common/common.h" // for EscapeU8
#include "../predictor/predict_fn.h"
#include "io_utils.h" // for GetElem
#include "param.h"
@@ -31,26 +32,50 @@ namespace tree {
DMLC_REGISTER_PARAMETER(TrainParam);
}
namespace {
template <typename Float>
std::enable_if_t<std::is_floating_point_v<Float>, std::string> ToStr(Float value) {
int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
static_assert(std::is_floating_point<Float>::value,
"Use std::to_string instead for non-floating point values.");
std::stringstream ss;
ss << std::setprecision(kFloatMaxPrecision) << value;
return ss.str();
}
template <typename Float>
std::string ToStr(linalg::VectorView<Float> value, bst_target_t limit) {
int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
static_assert(std::is_floating_point<Float>::value,
"Use std::to_string instead for non-floating point values.");
std::stringstream ss;
ss << std::setprecision(kFloatMaxPrecision);
if (value.Size() == 1) {
ss << value(0);
return ss.str();
}
CHECK_GE(limit, 2);
auto n = std::min(static_cast<bst_target_t>(value.Size() - 1), limit - 1);
ss << "[";
for (std::size_t i = 0; i < n; ++i) {
ss << value(i) << ", ";
}
if (value.Size() > limit) {
ss << "..., ";
}
ss << value(value.Size() - 1) << "]";
return ss.str();
}
} // namespace
/*!
* \brief Base class for dump model implementation, modeling closely after code generator.
*/
class TreeGenerator {
protected:
static int32_t constexpr kFloatMaxPrecision =
std::numeric_limits<bst_float>::max_digits10;
FeatureMap const& fmap_;
std::stringstream ss_;
bool const with_stats_;
template <typename Float>
static std::string ToStr(Float value) {
static_assert(std::is_floating_point<Float>::value,
"Use std::to_string instead for non-floating point values.");
std::stringstream ss;
ss << std::setprecision(kFloatMaxPrecision) << value;
return ss.str();
}
static std::string Tabs(uint32_t n) {
std::string res;
for (uint32_t i = 0; i < n; ++i) {
@@ -258,10 +283,10 @@ class TextGenerator : public TreeGenerator {
kLeafTemplate,
{{"{tabs}", SuperT::Tabs(depth)},
{"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
{"{leaf}", ToStr(tree[nid].LeafValue())},
{"{stats}", with_stats_ ?
SuperT::Match(kStatTemplate,
{{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
{{"{cover}", ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result;
}
@@ -311,14 +336,14 @@ class TextGenerator : public TreeGenerator {
static std::string const kQuantitiveTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
}
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
}
std::string Categorical(RegTree const &tree, int32_t nid,
@@ -336,8 +361,8 @@ class TextGenerator : public TreeGenerator {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match(
kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
{{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
return result;
}
@@ -393,11 +418,11 @@ class JsonGenerator : public TreeGenerator {
std::string result = SuperT::Match(
kLeafTemplate,
{{"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
{"{leaf}", ToStr(tree[nid].LeafValue())},
{"{stat}", with_stats_ ? SuperT::Match(
kStatTemplate,
{{"{sum_hess}",
SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result;
}
@@ -468,7 +493,7 @@ class JsonGenerator : public TreeGenerator {
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
bst_float cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
}
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
@@ -477,7 +502,7 @@ class JsonGenerator : public TreeGenerator {
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
}
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
@@ -485,8 +510,8 @@ class JsonGenerator : public TreeGenerator {
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
auto result = SuperT::Match(
kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
{{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
return result;
}
@@ -622,11 +647,11 @@ class GraphvizGenerator : public TreeGenerator {
protected:
template <bool is_categorical>
std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const {
std::string BuildEdge(RegTree const &tree, bst_node_t nidx, int32_t child, bool left) const {
static std::string const kEdgeTemplate =
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
// Is this the default child for missing value?
bool is_missing = tree[nid].DefaultChild() == child;
bool is_missing = tree.DefaultChild(nidx) == child;
std::string branch;
if (is_categorical) {
branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""};
@@ -635,7 +660,7 @@ class GraphvizGenerator : public TreeGenerator {
}
std::string buffer =
SuperT::Match(kEdgeTemplate,
{{"{nid}", std::to_string(nid)},
{{"{nid}", std::to_string(nidx)},
{"{child}", std::to_string(child)},
{"{color}", is_missing ? param_.yes_color : param_.no_color},
{"{branch}", branch}});
@@ -644,68 +669,77 @@ class GraphvizGenerator : public TreeGenerator {
// Only indicator is different, so we combine all different node types into this
// function.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
auto split_index = tree[nid].SplitIndex();
auto cond = tree[nid].SplitCond();
std::string PlainNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
auto split_index = tree.SplitIndex(nidx);
auto cond = tree.SplitCond(nidx);
static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
bool has_less =
(split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator;
std::string result =
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nid)},
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
{"{<}", has_less ? "<" : ""},
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
{"{cond}", has_less ? ToStr(cond) : ""},
{"{params}", param_.condition_node_params}});
result += BuildEdge<false>(tree, nid, tree[nid].LeftChild(), true);
result += BuildEdge<false>(tree, nid, tree[nid].RightChild(), false);
result += BuildEdge<false>(tree, nidx, tree.LeftChild(nidx), true);
result += BuildEdge<false>(tree, nidx, tree.RightChild(nidx), false);
return result;
};
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override {
std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
static std::string const kLabelTemplate =
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
auto cats = GetSplitCategories(tree, nid);
auto cats = GetSplitCategories(tree, nidx);
auto cats_str = PrintCatsAsSet(cats);
auto split_index = tree[nid].SplitIndex();
auto split_index = tree.SplitIndex(nidx);
std::string result =
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nid)},
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
{"{cond}", cats_str},
{"{params}", param_.condition_node_params}});
result += BuildEdge<true>(tree, nid, tree[nid].LeftChild(), true);
result += BuildEdge<true>(tree, nid, tree[nid].RightChild(), false);
result += BuildEdge<true>(tree, nidx, tree.LeftChild(nidx), true);
result += BuildEdge<true>(tree, nidx, tree.RightChild(nidx), false);
return result;
}
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
static std::string const kLeafTemplate =
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
auto result = SuperT::Match(kLeafTemplate, {
{"{nid}", std::to_string(nid)},
{"{leaf-value}", ToStr(tree[nid].LeafValue())},
{"{params}", param_.leaf_node_params}});
return result;
};
std::string LeafNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
static std::string const kLeafTemplate = " {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
// hardcoded limit to avoid dumping long arrays into dot graph.
bst_target_t constexpr kLimit{3};
if (tree.IsMultiTarget()) {
auto value = tree.GetMultiTargetTree()->LeafValue(nidx);
auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)},
{"{leaf-value}", ToStr(value, kLimit)},
{"{params}", param_.leaf_node_params}});
return result;
} else {
auto value = tree[nidx].LeafValue();
auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)},
{"{leaf-value}", ToStr(value)},
{"{params}", param_.leaf_node_params}});
return result;
}
}
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
if (tree[nid].IsLeaf()) {
return this->LeafNode(tree, nid, depth);
std::string BuildTree(RegTree const& tree, bst_node_t nidx, uint32_t depth) override {
if (tree.IsLeaf(nidx)) {
return this->LeafNode(tree, nidx, depth);
}
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical
? this->Categorical(tree, nid, depth)
: this->PlainNode(tree, nid, depth);
auto node = tree.GetSplitTypes()[nidx] == FeatureType::kCategorical
? this->Categorical(tree, nidx, depth)
: this->PlainNode(tree, nidx, depth);
auto result = SuperT::Match(
kNodeTemplate,
{{"{parent}", node},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
{"{left}", this->BuildTree(tree, tree.LeftChild(nidx), depth+1)},
{"{right}", this->BuildTree(tree, tree.RightChild(nidx), depth+1)}});
return result;
}
@@ -733,7 +767,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
constexpr bst_node_t RegTree::kRoot;
std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const {
CHECK(!IsMultiTarget());
if (this->IsMultiTarget() && format != "dot") {
LOG(FATAL) << format << " tree dump " << MTNotImplemented();
}
std::unique_ptr<TreeGenerator> builder{TreeGenerator::Create(format, fmap, with_stats)};
builder->BuildTree(*this);