Support graphviz plot for multi-target tree. (#10093)
This commit is contained in:
parent
e14c3b9325
commit
2c13f90384
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2023 by Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file tree_model.h
|
||||
* \brief model structure for tree
|
||||
* \author Tianqi Chen
|
||||
@ -688,6 +688,9 @@ class RegTree : public Model {
|
||||
}
|
||||
return (*this)[nidx].DefaultLeft();
|
||||
}
|
||||
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
|
||||
return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
|
||||
}
|
||||
[[nodiscard]] bool IsRoot(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return nidx == kRoot;
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2015-2023, XGBoost Contributors
|
||||
* Copyright 2015-2024, XGBoost Contributors
|
||||
* \file tree_model.cc
|
||||
* \brief model structure for tree
|
||||
*/
|
||||
@ -8,6 +8,7 @@
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
|
||||
#include <array> // for array
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
@ -31,19 +32,10 @@ namespace tree {
|
||||
DMLC_REGISTER_PARAMETER(TrainParam);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Base class for dump model implementation, modeling closely after code generator.
|
||||
*/
|
||||
class TreeGenerator {
|
||||
protected:
|
||||
static int32_t constexpr kFloatMaxPrecision =
|
||||
std::numeric_limits<bst_float>::max_digits10;
|
||||
FeatureMap const& fmap_;
|
||||
std::stringstream ss_;
|
||||
bool const with_stats_;
|
||||
|
||||
namespace {
|
||||
template <typename Float>
|
||||
static std::string ToStr(Float value) {
|
||||
std::enable_if_t<std::is_floating_point_v<Float>, std::string> ToStr(Float value) {
|
||||
int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
|
||||
static_assert(std::is_floating_point<Float>::value,
|
||||
"Use std::to_string instead for non-floating point values.");
|
||||
std::stringstream ss;
|
||||
@ -51,6 +43,39 @@ class TreeGenerator {
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
template <typename Float>
|
||||
std::string ToStr(linalg::VectorView<Float> value, bst_target_t limit) {
|
||||
int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
|
||||
static_assert(std::is_floating_point<Float>::value,
|
||||
"Use std::to_string instead for non-floating point values.");
|
||||
std::stringstream ss;
|
||||
ss << std::setprecision(kFloatMaxPrecision);
|
||||
if (value.Size() == 1) {
|
||||
ss << value(0);
|
||||
return ss.str();
|
||||
}
|
||||
CHECK_GE(limit, 2);
|
||||
auto n = std::min(static_cast<bst_target_t>(value.Size() - 1), limit - 1);
|
||||
ss << "[";
|
||||
for (std::size_t i = 0; i < n; ++i) {
|
||||
ss << value(i) << ", ";
|
||||
}
|
||||
if (value.Size() > limit) {
|
||||
ss << "..., ";
|
||||
}
|
||||
ss << value(value.Size() - 1) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
} // namespace
|
||||
/*!
|
||||
* \brief Base class for dump model implementation, modeling closely after code generator.
|
||||
*/
|
||||
class TreeGenerator {
|
||||
protected:
|
||||
FeatureMap const& fmap_;
|
||||
std::stringstream ss_;
|
||||
bool const with_stats_;
|
||||
|
||||
static std::string Tabs(uint32_t n) {
|
||||
std::string res;
|
||||
for (uint32_t i = 0; i < n; ++i) {
|
||||
@ -258,10 +283,10 @@ class TextGenerator : public TreeGenerator {
|
||||
kLeafTemplate,
|
||||
{{"{tabs}", SuperT::Tabs(depth)},
|
||||
{"{nid}", std::to_string(nid)},
|
||||
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
|
||||
{"{leaf}", ToStr(tree[nid].LeafValue())},
|
||||
{"{stats}", with_stats_ ?
|
||||
SuperT::Match(kStatTemplate,
|
||||
{{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
|
||||
{{"{cover}", ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -311,14 +336,14 @@ class TextGenerator : public TreeGenerator {
|
||||
static std::string const kQuantitiveTemplate =
|
||||
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||
auto cond = tree[nid].SplitCond();
|
||||
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
||||
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
|
||||
}
|
||||
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
auto cond = tree[nid].SplitCond();
|
||||
static std::string const kNodeTemplate =
|
||||
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
|
||||
}
|
||||
|
||||
std::string Categorical(RegTree const &tree, int32_t nid,
|
||||
@ -336,8 +361,8 @@ class TextGenerator : public TreeGenerator {
|
||||
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
|
||||
std::string const result = SuperT::Match(
|
||||
kStatTemplate,
|
||||
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
|
||||
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
|
||||
{{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
|
||||
{"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -393,11 +418,11 @@ class JsonGenerator : public TreeGenerator {
|
||||
std::string result = SuperT::Match(
|
||||
kLeafTemplate,
|
||||
{{"{nid}", std::to_string(nid)},
|
||||
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
|
||||
{"{leaf}", ToStr(tree[nid].LeafValue())},
|
||||
{"{stat}", with_stats_ ? SuperT::Match(
|
||||
kStatTemplate,
|
||||
{{"{sum_hess}",
|
||||
SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
|
||||
ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -468,7 +493,7 @@ class JsonGenerator : public TreeGenerator {
|
||||
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
||||
R"I("missing": {missing})I";
|
||||
bst_float cond = tree[nid].SplitCond();
|
||||
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
||||
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
|
||||
}
|
||||
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
@ -477,7 +502,7 @@ class JsonGenerator : public TreeGenerator {
|
||||
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
||||
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
||||
R"I("missing": {missing})I";
|
||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
|
||||
}
|
||||
|
||||
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
|
||||
@ -485,8 +510,8 @@ class JsonGenerator : public TreeGenerator {
|
||||
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
|
||||
auto result = SuperT::Match(
|
||||
kStatTemplate,
|
||||
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
|
||||
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
|
||||
{{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
|
||||
{"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -622,11 +647,11 @@ class GraphvizGenerator : public TreeGenerator {
|
||||
|
||||
protected:
|
||||
template <bool is_categorical>
|
||||
std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const {
|
||||
std::string BuildEdge(RegTree const &tree, bst_node_t nidx, int32_t child, bool left) const {
|
||||
static std::string const kEdgeTemplate =
|
||||
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
|
||||
// Is this the default child for missing value?
|
||||
bool is_missing = tree[nid].DefaultChild() == child;
|
||||
bool is_missing = tree.DefaultChild(nidx) == child;
|
||||
std::string branch;
|
||||
if (is_categorical) {
|
||||
branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""};
|
||||
@ -635,7 +660,7 @@ class GraphvizGenerator : public TreeGenerator {
|
||||
}
|
||||
std::string buffer =
|
||||
SuperT::Match(kEdgeTemplate,
|
||||
{{"{nid}", std::to_string(nid)},
|
||||
{{"{nid}", std::to_string(nidx)},
|
||||
{"{child}", std::to_string(child)},
|
||||
{"{color}", is_missing ? param_.yes_color : param_.no_color},
|
||||
{"{branch}", branch}});
|
||||
@ -644,68 +669,77 @@ class GraphvizGenerator : public TreeGenerator {
|
||||
|
||||
// Only indicator is different, so we combine all different node types into this
|
||||
// function.
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||
auto split_index = tree[nid].SplitIndex();
|
||||
auto cond = tree[nid].SplitCond();
|
||||
std::string PlainNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
|
||||
auto split_index = tree.SplitIndex(nidx);
|
||||
auto cond = tree.SplitCond(nidx);
|
||||
static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
|
||||
|
||||
bool has_less =
|
||||
(split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator;
|
||||
std::string result =
|
||||
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nid)},
|
||||
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)},
|
||||
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||
{"{<}", has_less ? "<" : ""},
|
||||
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
|
||||
{"{cond}", has_less ? ToStr(cond) : ""},
|
||||
{"{params}", param_.condition_node_params}});
|
||||
|
||||
result += BuildEdge<false>(tree, nid, tree[nid].LeftChild(), true);
|
||||
result += BuildEdge<false>(tree, nid, tree[nid].RightChild(), false);
|
||||
result += BuildEdge<false>(tree, nidx, tree.LeftChild(nidx), true);
|
||||
result += BuildEdge<false>(tree, nidx, tree.RightChild(nidx), false);
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||
std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
|
||||
static std::string const kLabelTemplate =
|
||||
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
|
||||
auto cats = GetSplitCategories(tree, nid);
|
||||
auto cats = GetSplitCategories(tree, nidx);
|
||||
auto cats_str = PrintCatsAsSet(cats);
|
||||
auto split_index = tree[nid].SplitIndex();
|
||||
auto split_index = tree.SplitIndex(nidx);
|
||||
|
||||
std::string result =
|
||||
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nid)},
|
||||
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nidx)},
|
||||
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||
{"{cond}", cats_str},
|
||||
{"{params}", param_.condition_node_params}});
|
||||
|
||||
result += BuildEdge<true>(tree, nid, tree[nid].LeftChild(), true);
|
||||
result += BuildEdge<true>(tree, nid, tree[nid].RightChild(), false);
|
||||
result += BuildEdge<true>(tree, nidx, tree.LeftChild(nidx), true);
|
||||
result += BuildEdge<true>(tree, nidx, tree.RightChild(nidx), false);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||
static std::string const kLeafTemplate =
|
||||
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
|
||||
auto result = SuperT::Match(kLeafTemplate, {
|
||||
{"{nid}", std::to_string(nid)},
|
||||
{"{leaf-value}", ToStr(tree[nid].LeafValue())},
|
||||
std::string LeafNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
|
||||
static std::string const kLeafTemplate = " {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
|
||||
// hardcoded limit to avoid dumping long arrays into dot graph.
|
||||
bst_target_t constexpr kLimit{3};
|
||||
if (tree.IsMultiTarget()) {
|
||||
auto value = tree.GetMultiTargetTree()->LeafValue(nidx);
|
||||
auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)},
|
||||
{"{leaf-value}", ToStr(value, kLimit)},
|
||||
{"{params}", param_.leaf_node_params}});
|
||||
return result;
|
||||
};
|
||||
} else {
|
||||
auto value = tree[nidx].LeafValue();
|
||||
auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)},
|
||||
{"{leaf-value}", ToStr(value)},
|
||||
{"{params}", param_.leaf_node_params}});
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
if (tree[nid].IsLeaf()) {
|
||||
return this->LeafNode(tree, nid, depth);
|
||||
std::string BuildTree(RegTree const& tree, bst_node_t nidx, uint32_t depth) override {
|
||||
if (tree.IsLeaf(nidx)) {
|
||||
return this->LeafNode(tree, nidx, depth);
|
||||
}
|
||||
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
|
||||
auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical
|
||||
? this->Categorical(tree, nid, depth)
|
||||
: this->PlainNode(tree, nid, depth);
|
||||
auto node = tree.GetSplitTypes()[nidx] == FeatureType::kCategorical
|
||||
? this->Categorical(tree, nidx, depth)
|
||||
: this->PlainNode(tree, nidx, depth);
|
||||
auto result = SuperT::Match(
|
||||
kNodeTemplate,
|
||||
{{"{parent}", node},
|
||||
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
|
||||
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
|
||||
{"{left}", this->BuildTree(tree, tree.LeftChild(nidx), depth+1)},
|
||||
{"{right}", this->BuildTree(tree, tree.RightChild(nidx), depth+1)}});
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -733,7 +767,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
|
||||
constexpr bst_node_t RegTree::kRoot;
|
||||
|
||||
std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const {
|
||||
CHECK(!IsMultiTarget());
|
||||
if (this->IsMultiTarget() && format != "dot") {
|
||||
LOG(FATAL) << format << " tree dump " << MTNotImplemented();
|
||||
}
|
||||
std::unique_ptr<TreeGenerator> builder{TreeGenerator::Create(format, fmap, with_stats)};
|
||||
builder->BuildTree(*this);
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/context.h> // for Context
|
||||
@ -7,16 +7,23 @@
|
||||
#include <xgboost/tree_model.h> // for RegTree
|
||||
|
||||
namespace xgboost {
|
||||
TEST(MultiTargetTree, JsonIO) {
|
||||
namespace {
|
||||
auto MakeTreeForTest() {
|
||||
bst_target_t n_targets{3};
|
||||
bst_feature_t n_features{4};
|
||||
RegTree tree{n_targets, n_features};
|
||||
ASSERT_TRUE(tree.IsMultiTarget());
|
||||
CHECK(tree.IsMultiTarget());
|
||||
linalg::Vector<float> base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()};
|
||||
linalg::Vector<float> left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()};
|
||||
linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()};
|
||||
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
|
||||
left_weight.HostView(), right_weight.HostView());
|
||||
return tree;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(MultiTargetTree, JsonIO) {
|
||||
auto tree = MakeTreeForTest();
|
||||
ASSERT_EQ(tree.NumNodes(), 3);
|
||||
ASSERT_EQ(tree.NumTargets(), 3);
|
||||
ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
|
||||
@ -44,4 +51,28 @@ TEST(MultiTargetTree, JsonIO) {
|
||||
loaded.SaveModel(&jtree1);
|
||||
check_jtree(jtree1, tree);
|
||||
}
|
||||
|
||||
TEST(MultiTargetTree, DumpDot) {
|
||||
auto tree = MakeTreeForTest();
|
||||
auto n_features = tree.NumFeatures();
|
||||
FeatureMap fmap;
|
||||
for (bst_feature_t f = 0; f < n_features; ++f) {
|
||||
auto name = "feat_" + std::to_string(f);
|
||||
fmap.PushBack(f, name.c_str(), "q");
|
||||
}
|
||||
auto str = tree.DumpModel(fmap, true, "dot");
|
||||
ASSERT_NE(str.find("leaf=[2, 3, 4]"), std::string::npos);
|
||||
ASSERT_NE(str.find("leaf=[3, 4, 5]"), std::string::npos);
|
||||
|
||||
{
|
||||
bst_target_t n_targets{4};
|
||||
bst_feature_t n_features{4};
|
||||
RegTree tree{n_targets, n_features};
|
||||
linalg::Vector<float> weight{{1.0f, 2.0f, 3.0f, 4.0f}, {4ul}, DeviceOrd::CPU()};
|
||||
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, weight.HostView(),
|
||||
weight.HostView(), weight.HostView());
|
||||
auto str = tree.DumpModel(fmap, true, "dot");
|
||||
ASSERT_NE(str.find("leaf=[1, 2, ..., 4]"), std::string::npos);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user