Support graphviz plot for multi-target tree. (#10093)

This commit is contained in:
Jiaming Yuan 2024-03-09 05:35:25 +08:00 committed by GitHub
parent e14c3b9325
commit 2c13f90384
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 133 additions and 63 deletions

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2014-2023 by Contributors * Copyright 2014-2024, XGBoost Contributors
* \file tree_model.h * \file tree_model.h
* \brief model structure for tree * \brief model structure for tree
* \author Tianqi Chen * \author Tianqi Chen
@ -688,6 +688,9 @@ class RegTree : public Model {
} }
return (*this)[nidx].DefaultLeft(); return (*this)[nidx].DefaultLeft();
} }
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
}
[[nodiscard]] bool IsRoot(bst_node_t nidx) const { [[nodiscard]] bool IsRoot(bst_node_t nidx) const {
if (IsMultiTarget()) { if (IsMultiTarget()) {
return nidx == kRoot; return nidx == kRoot;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2015-2023, XGBoost Contributors * Copyright 2015-2024, XGBoost Contributors
* \file tree_model.cc * \file tree_model.cc
* \brief model structure for tree * \brief model structure for tree
*/ */
@ -8,6 +8,7 @@
#include <xgboost/json.h> #include <xgboost/json.h>
#include <xgboost/tree_model.h> #include <xgboost/tree_model.h>
#include <array> // for array
#include <cmath> #include <cmath>
#include <iomanip> #include <iomanip>
#include <limits> #include <limits>
@ -31,25 +32,49 @@ namespace tree {
DMLC_REGISTER_PARAMETER(TrainParam); DMLC_REGISTER_PARAMETER(TrainParam);
} }
/*! namespace {
* \brief Base class for dump model implementation, modeling closely after code generator. template <typename Float>
*/ std::enable_if_t<std::is_floating_point_v<Float>, std::string> ToStr(Float value) {
class TreeGenerator { int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
protected:
static int32_t constexpr kFloatMaxPrecision =
std::numeric_limits<bst_float>::max_digits10;
FeatureMap const& fmap_;
std::stringstream ss_;
bool const with_stats_;
template <typename Float>
static std::string ToStr(Float value) {
static_assert(std::is_floating_point<Float>::value, static_assert(std::is_floating_point<Float>::value,
"Use std::to_string instead for non-floating point values."); "Use std::to_string instead for non-floating point values.");
std::stringstream ss; std::stringstream ss;
ss << std::setprecision(kFloatMaxPrecision) << value; ss << std::setprecision(kFloatMaxPrecision) << value;
return ss.str(); return ss.str();
}
template <typename Float>
std::string ToStr(linalg::VectorView<Float> value, bst_target_t limit) {
int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
static_assert(std::is_floating_point<Float>::value,
"Use std::to_string instead for non-floating point values.");
std::stringstream ss;
ss << std::setprecision(kFloatMaxPrecision);
if (value.Size() == 1) {
ss << value(0);
return ss.str();
} }
CHECK_GE(limit, 2);
auto n = std::min(static_cast<bst_target_t>(value.Size() - 1), limit - 1);
ss << "[";
for (std::size_t i = 0; i < n; ++i) {
ss << value(i) << ", ";
}
if (value.Size() > limit) {
ss << "..., ";
}
ss << value(value.Size() - 1) << "]";
return ss.str();
}
} // namespace
/*!
* \brief Base class for dump model implementation, modeling closely after code generator.
*/
class TreeGenerator {
protected:
FeatureMap const& fmap_;
std::stringstream ss_;
bool const with_stats_;
static std::string Tabs(uint32_t n) { static std::string Tabs(uint32_t n) {
std::string res; std::string res;
@ -258,10 +283,10 @@ class TextGenerator : public TreeGenerator {
kLeafTemplate, kLeafTemplate,
{{"{tabs}", SuperT::Tabs(depth)}, {{"{tabs}", SuperT::Tabs(depth)},
{"{nid}", std::to_string(nid)}, {"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())}, {"{leaf}", ToStr(tree[nid].LeafValue())},
{"{stats}", with_stats_ ? {"{stats}", with_stats_ ?
SuperT::Match(kStatTemplate, SuperT::Match(kStatTemplate,
{{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); {{"{cover}", ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result; return result;
} }
@ -311,14 +336,14 @@ class TextGenerator : public TreeGenerator {
static std::string const kQuantitiveTemplate = static std::string const kQuantitiveTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond(); auto cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
} }
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto cond = tree[nid].SplitCond(); auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate = static std::string const kNodeTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
} }
std::string Categorical(RegTree const &tree, int32_t nid, std::string Categorical(RegTree const &tree, int32_t nid,
@ -336,8 +361,8 @@ class TextGenerator : public TreeGenerator {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}"; static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match( std::string const result = SuperT::Match(
kStatTemplate, kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)}, {{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}); {"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
return result; return result;
} }
@ -393,11 +418,11 @@ class JsonGenerator : public TreeGenerator {
std::string result = SuperT::Match( std::string result = SuperT::Match(
kLeafTemplate, kLeafTemplate,
{{"{nid}", std::to_string(nid)}, {{"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())}, {"{leaf}", ToStr(tree[nid].LeafValue())},
{"{stat}", with_stats_ ? SuperT::Match( {"{stat}", with_stats_ ? SuperT::Match(
kStatTemplate, kStatTemplate,
{{"{sum_hess}", {{"{sum_hess}",
SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result; return result;
} }
@ -468,7 +493,7 @@ class JsonGenerator : public TreeGenerator {
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I" R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I"; R"I("missing": {missing})I";
bst_float cond = tree[nid].SplitCond(); bst_float cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
} }
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
@ -477,7 +502,7 @@ class JsonGenerator : public TreeGenerator {
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I" R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I" R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I"; R"I("missing": {missing})I";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
} }
std::string NodeStat(RegTree const& tree, int32_t nid) const override { std::string NodeStat(RegTree const& tree, int32_t nid) const override {
@ -485,8 +510,8 @@ class JsonGenerator : public TreeGenerator {
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S"; R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
auto result = SuperT::Match( auto result = SuperT::Match(
kStatTemplate, kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)}, {{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}); {"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
return result; return result;
} }
@ -622,11 +647,11 @@ class GraphvizGenerator : public TreeGenerator {
protected: protected:
template <bool is_categorical> template <bool is_categorical>
std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const { std::string BuildEdge(RegTree const &tree, bst_node_t nidx, int32_t child, bool left) const {
static std::string const kEdgeTemplate = static std::string const kEdgeTemplate =
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n"; " {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
// Is this the default child for missing value? // Is this the default child for missing value?
bool is_missing = tree[nid].DefaultChild() == child; bool is_missing = tree.DefaultChild(nidx) == child;
std::string branch; std::string branch;
if (is_categorical) { if (is_categorical) {
branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""}; branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""};
@ -635,7 +660,7 @@ class GraphvizGenerator : public TreeGenerator {
} }
std::string buffer = std::string buffer =
SuperT::Match(kEdgeTemplate, SuperT::Match(kEdgeTemplate,
{{"{nid}", std::to_string(nid)}, {{"{nid}", std::to_string(nidx)},
{"{child}", std::to_string(child)}, {"{child}", std::to_string(child)},
{"{color}", is_missing ? param_.yes_color : param_.no_color}, {"{color}", is_missing ? param_.yes_color : param_.no_color},
{"{branch}", branch}}); {"{branch}", branch}});
@ -644,68 +669,77 @@ class GraphvizGenerator : public TreeGenerator {
// Only indicator is different, so we combine all different node types into this // Only indicator is different, so we combine all different node types into this
// function. // function.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override { std::string PlainNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
auto split_index = tree[nid].SplitIndex(); auto split_index = tree.SplitIndex(nidx);
auto cond = tree[nid].SplitCond(); auto cond = tree.SplitCond(nidx);
static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n"; static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
bool has_less = bool has_less =
(split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator; (split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator;
std::string result = std::string result =
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nid)}, SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)}, {"{fname}", GetFeatureName(fmap_, split_index)},
{"{<}", has_less ? "<" : ""}, {"{<}", has_less ? "<" : ""},
{"{cond}", has_less ? SuperT::ToStr(cond) : ""}, {"{cond}", has_less ? ToStr(cond) : ""},
{"{params}", param_.condition_node_params}}); {"{params}", param_.condition_node_params}});
result += BuildEdge<false>(tree, nid, tree[nid].LeftChild(), true); result += BuildEdge<false>(tree, nidx, tree.LeftChild(nidx), true);
result += BuildEdge<false>(tree, nid, tree[nid].RightChild(), false); result += BuildEdge<false>(tree, nidx, tree.RightChild(nidx), false);
return result; return result;
}; };
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override { std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
static std::string const kLabelTemplate = static std::string const kLabelTemplate =
" {nid} [ label=\"{fname}:{cond}\" {params}]\n"; " {nid} [ label=\"{fname}:{cond}\" {params}]\n";
auto cats = GetSplitCategories(tree, nid); auto cats = GetSplitCategories(tree, nidx);
auto cats_str = PrintCatsAsSet(cats); auto cats_str = PrintCatsAsSet(cats);
auto split_index = tree[nid].SplitIndex(); auto split_index = tree.SplitIndex(nidx);
std::string result = std::string result =
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nid)}, SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)}, {"{fname}", GetFeatureName(fmap_, split_index)},
{"{cond}", cats_str}, {"{cond}", cats_str},
{"{params}", param_.condition_node_params}}); {"{params}", param_.condition_node_params}});
result += BuildEdge<true>(tree, nid, tree[nid].LeftChild(), true); result += BuildEdge<true>(tree, nidx, tree.LeftChild(nidx), true);
result += BuildEdge<true>(tree, nid, tree[nid].RightChild(), false); result += BuildEdge<true>(tree, nidx, tree.RightChild(nidx), false);
return result; return result;
} }
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override { std::string LeafNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
static std::string const kLeafTemplate = static std::string const kLeafTemplate = " {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; // hardcoded limit to avoid dumping long arrays into dot graph.
auto result = SuperT::Match(kLeafTemplate, { bst_target_t constexpr kLimit{3};
{"{nid}", std::to_string(nid)}, if (tree.IsMultiTarget()) {
{"{leaf-value}", ToStr(tree[nid].LeafValue())}, auto value = tree.GetMultiTargetTree()->LeafValue(nidx);
auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)},
{"{leaf-value}", ToStr(value, kLimit)},
{"{params}", param_.leaf_node_params}}); {"{params}", param_.leaf_node_params}});
return result; return result;
}; } else {
auto value = tree[nidx].LeafValue();
auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)},
{"{leaf-value}", ToStr(value)},
{"{params}", param_.leaf_node_params}});
return result;
}
}
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string BuildTree(RegTree const& tree, bst_node_t nidx, uint32_t depth) override {
if (tree[nid].IsLeaf()) { if (tree.IsLeaf(nidx)) {
return this->LeafNode(tree, nid, depth); return this->LeafNode(tree, nidx, depth);
} }
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}"; static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical auto node = tree.GetSplitTypes()[nidx] == FeatureType::kCategorical
? this->Categorical(tree, nid, depth) ? this->Categorical(tree, nidx, depth)
: this->PlainNode(tree, nid, depth); : this->PlainNode(tree, nidx, depth);
auto result = SuperT::Match( auto result = SuperT::Match(
kNodeTemplate, kNodeTemplate,
{{"{parent}", node}, {{"{parent}", node},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)}, {"{left}", this->BuildTree(tree, tree.LeftChild(nidx), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}}); {"{right}", this->BuildTree(tree, tree.RightChild(nidx), depth+1)}});
return result; return result;
} }
@ -733,7 +767,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
constexpr bst_node_t RegTree::kRoot; constexpr bst_node_t RegTree::kRoot;
std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const { std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const {
CHECK(!IsMultiTarget()); if (this->IsMultiTarget() && format != "dot") {
LOG(FATAL) << format << " tree dump " << MTNotImplemented();
}
std::unique_ptr<TreeGenerator> builder{TreeGenerator::Create(format, fmap, with_stats)}; std::unique_ptr<TreeGenerator> builder{TreeGenerator::Create(format, fmap, with_stats)};
builder->BuildTree(*this); builder->BuildTree(*this);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2023 by XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/context.h> // for Context #include <xgboost/context.h> // for Context
@ -7,16 +7,23 @@
#include <xgboost/tree_model.h> // for RegTree #include <xgboost/tree_model.h> // for RegTree
namespace xgboost { namespace xgboost {
TEST(MultiTargetTree, JsonIO) { namespace {
auto MakeTreeForTest() {
bst_target_t n_targets{3}; bst_target_t n_targets{3};
bst_feature_t n_features{4}; bst_feature_t n_features{4};
RegTree tree{n_targets, n_features}; RegTree tree{n_targets, n_features};
ASSERT_TRUE(tree.IsMultiTarget()); CHECK(tree.IsMultiTarget());
linalg::Vector<float> base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()}; linalg::Vector<float> base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()};
linalg::Vector<float> left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()}; linalg::Vector<float> left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()};
linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()}; linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()};
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(), tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
left_weight.HostView(), right_weight.HostView()); left_weight.HostView(), right_weight.HostView());
return tree;
}
} // namespace
TEST(MultiTargetTree, JsonIO) {
auto tree = MakeTreeForTest();
ASSERT_EQ(tree.NumNodes(), 3); ASSERT_EQ(tree.NumNodes(), 3);
ASSERT_EQ(tree.NumTargets(), 3); ASSERT_EQ(tree.NumTargets(), 3);
ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3); ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
@ -44,4 +51,28 @@ TEST(MultiTargetTree, JsonIO) {
loaded.SaveModel(&jtree1); loaded.SaveModel(&jtree1);
check_jtree(jtree1, tree); check_jtree(jtree1, tree);
} }
TEST(MultiTargetTree, DumpDot) {
auto tree = MakeTreeForTest();
auto n_features = tree.NumFeatures();
FeatureMap fmap;
for (bst_feature_t f = 0; f < n_features; ++f) {
auto name = "feat_" + std::to_string(f);
fmap.PushBack(f, name.c_str(), "q");
}
auto str = tree.DumpModel(fmap, true, "dot");
ASSERT_NE(str.find("leaf=[2, 3, 4]"), std::string::npos);
ASSERT_NE(str.find("leaf=[3, 4, 5]"), std::string::npos);
{
bst_target_t n_targets{4};
bst_feature_t n_features{4};
RegTree tree{n_targets, n_features};
linalg::Vector<float> weight{{1.0f, 2.0f, 3.0f, 4.0f}, {4ul}, DeviceOrd::CPU()};
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, weight.HostView(),
weight.HostView(), weight.HostView());
auto str = tree.DumpModel(fmap, true, "dot");
ASSERT_NE(str.find("leaf=[1, 2, ..., 4]"), std::string::npos);
}
}
} // namespace xgboost } // namespace xgboost