Support graphviz plot for multi-target tree. (#10093)
This commit is contained in:
parent
e14c3b9325
commit
2c13f90384
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2014-2023 by Contributors
|
* Copyright 2014-2024, XGBoost Contributors
|
||||||
* \file tree_model.h
|
* \file tree_model.h
|
||||||
* \brief model structure for tree
|
* \brief model structure for tree
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
@ -688,6 +688,9 @@ class RegTree : public Model {
|
|||||||
}
|
}
|
||||||
return (*this)[nidx].DefaultLeft();
|
return (*this)[nidx].DefaultLeft();
|
||||||
}
|
}
|
||||||
|
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
|
||||||
|
return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
|
||||||
|
}
|
||||||
[[nodiscard]] bool IsRoot(bst_node_t nidx) const {
|
[[nodiscard]] bool IsRoot(bst_node_t nidx) const {
|
||||||
if (IsMultiTarget()) {
|
if (IsMultiTarget()) {
|
||||||
return nidx == kRoot;
|
return nidx == kRoot;
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2015-2023, XGBoost Contributors
|
* Copyright 2015-2024, XGBoost Contributors
|
||||||
* \file tree_model.cc
|
* \file tree_model.cc
|
||||||
* \brief model structure for tree
|
* \brief model structure for tree
|
||||||
*/
|
*/
|
||||||
@ -8,6 +8,7 @@
|
|||||||
#include <xgboost/json.h>
|
#include <xgboost/json.h>
|
||||||
#include <xgboost/tree_model.h>
|
#include <xgboost/tree_model.h>
|
||||||
|
|
||||||
|
#include <array> // for array
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
@ -31,25 +32,49 @@ namespace tree {
|
|||||||
DMLC_REGISTER_PARAMETER(TrainParam);
|
DMLC_REGISTER_PARAMETER(TrainParam);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
namespace {
|
||||||
* \brief Base class for dump model implementation, modeling closely after code generator.
|
template <typename Float>
|
||||||
*/
|
std::enable_if_t<std::is_floating_point_v<Float>, std::string> ToStr(Float value) {
|
||||||
class TreeGenerator {
|
int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
|
||||||
protected:
|
|
||||||
static int32_t constexpr kFloatMaxPrecision =
|
|
||||||
std::numeric_limits<bst_float>::max_digits10;
|
|
||||||
FeatureMap const& fmap_;
|
|
||||||
std::stringstream ss_;
|
|
||||||
bool const with_stats_;
|
|
||||||
|
|
||||||
template <typename Float>
|
|
||||||
static std::string ToStr(Float value) {
|
|
||||||
static_assert(std::is_floating_point<Float>::value,
|
static_assert(std::is_floating_point<Float>::value,
|
||||||
"Use std::to_string instead for non-floating point values.");
|
"Use std::to_string instead for non-floating point values.");
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << std::setprecision(kFloatMaxPrecision) << value;
|
ss << std::setprecision(kFloatMaxPrecision) << value;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Float>
|
||||||
|
std::string ToStr(linalg::VectorView<Float> value, bst_target_t limit) {
|
||||||
|
int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
|
||||||
|
static_assert(std::is_floating_point<Float>::value,
|
||||||
|
"Use std::to_string instead for non-floating point values.");
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << std::setprecision(kFloatMaxPrecision);
|
||||||
|
if (value.Size() == 1) {
|
||||||
|
ss << value(0);
|
||||||
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
CHECK_GE(limit, 2);
|
||||||
|
auto n = std::min(static_cast<bst_target_t>(value.Size() - 1), limit - 1);
|
||||||
|
ss << "[";
|
||||||
|
for (std::size_t i = 0; i < n; ++i) {
|
||||||
|
ss << value(i) << ", ";
|
||||||
|
}
|
||||||
|
if (value.Size() > limit) {
|
||||||
|
ss << "..., ";
|
||||||
|
}
|
||||||
|
ss << value(value.Size() - 1) << "]";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
/*!
|
||||||
|
* \brief Base class for dump model implementation, modeling closely after code generator.
|
||||||
|
*/
|
||||||
|
class TreeGenerator {
|
||||||
|
protected:
|
||||||
|
FeatureMap const& fmap_;
|
||||||
|
std::stringstream ss_;
|
||||||
|
bool const with_stats_;
|
||||||
|
|
||||||
static std::string Tabs(uint32_t n) {
|
static std::string Tabs(uint32_t n) {
|
||||||
std::string res;
|
std::string res;
|
||||||
@ -258,10 +283,10 @@ class TextGenerator : public TreeGenerator {
|
|||||||
kLeafTemplate,
|
kLeafTemplate,
|
||||||
{{"{tabs}", SuperT::Tabs(depth)},
|
{{"{tabs}", SuperT::Tabs(depth)},
|
||||||
{"{nid}", std::to_string(nid)},
|
{"{nid}", std::to_string(nid)},
|
||||||
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
|
{"{leaf}", ToStr(tree[nid].LeafValue())},
|
||||||
{"{stats}", with_stats_ ?
|
{"{stats}", with_stats_ ?
|
||||||
SuperT::Match(kStatTemplate,
|
SuperT::Match(kStatTemplate,
|
||||||
{{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
|
{{"{cover}", ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -311,14 +336,14 @@ class TextGenerator : public TreeGenerator {
|
|||||||
static std::string const kQuantitiveTemplate =
|
static std::string const kQuantitiveTemplate =
|
||||||
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||||
auto cond = tree[nid].SplitCond();
|
auto cond = tree[nid].SplitCond();
|
||||||
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
auto cond = tree[nid].SplitCond();
|
auto cond = tree[nid].SplitCond();
|
||||||
static std::string const kNodeTemplate =
|
static std::string const kNodeTemplate =
|
||||||
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Categorical(RegTree const &tree, int32_t nid,
|
std::string Categorical(RegTree const &tree, int32_t nid,
|
||||||
@ -336,8 +361,8 @@ class TextGenerator : public TreeGenerator {
|
|||||||
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
|
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
|
||||||
std::string const result = SuperT::Match(
|
std::string const result = SuperT::Match(
|
||||||
kStatTemplate,
|
kStatTemplate,
|
||||||
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
|
{{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
|
||||||
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
|
{"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -393,11 +418,11 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
std::string result = SuperT::Match(
|
std::string result = SuperT::Match(
|
||||||
kLeafTemplate,
|
kLeafTemplate,
|
||||||
{{"{nid}", std::to_string(nid)},
|
{{"{nid}", std::to_string(nid)},
|
||||||
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
|
{"{leaf}", ToStr(tree[nid].LeafValue())},
|
||||||
{"{stat}", with_stats_ ? SuperT::Match(
|
{"{stat}", with_stats_ ? SuperT::Match(
|
||||||
kStatTemplate,
|
kStatTemplate,
|
||||||
{{"{sum_hess}",
|
{{"{sum_hess}",
|
||||||
SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
|
ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -468,7 +493,7 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
||||||
R"I("missing": {missing})I";
|
R"I("missing": {missing})I";
|
||||||
bst_float cond = tree[nid].SplitCond();
|
bst_float cond = tree[nid].SplitCond();
|
||||||
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
@ -477,7 +502,7 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
||||||
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
||||||
R"I("missing": {missing})I";
|
R"I("missing": {missing})I";
|
||||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
|
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
|
||||||
@ -485,8 +510,8 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
|
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
|
||||||
auto result = SuperT::Match(
|
auto result = SuperT::Match(
|
||||||
kStatTemplate,
|
kStatTemplate,
|
||||||
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
|
{{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
|
||||||
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
|
{"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -622,11 +647,11 @@ class GraphvizGenerator : public TreeGenerator {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
template <bool is_categorical>
|
template <bool is_categorical>
|
||||||
std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const {
|
std::string BuildEdge(RegTree const &tree, bst_node_t nidx, int32_t child, bool left) const {
|
||||||
static std::string const kEdgeTemplate =
|
static std::string const kEdgeTemplate =
|
||||||
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
|
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
|
||||||
// Is this the default child for missing value?
|
// Is this the default child for missing value?
|
||||||
bool is_missing = tree[nid].DefaultChild() == child;
|
bool is_missing = tree.DefaultChild(nidx) == child;
|
||||||
std::string branch;
|
std::string branch;
|
||||||
if (is_categorical) {
|
if (is_categorical) {
|
||||||
branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""};
|
branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""};
|
||||||
@ -635,7 +660,7 @@ class GraphvizGenerator : public TreeGenerator {
|
|||||||
}
|
}
|
||||||
std::string buffer =
|
std::string buffer =
|
||||||
SuperT::Match(kEdgeTemplate,
|
SuperT::Match(kEdgeTemplate,
|
||||||
{{"{nid}", std::to_string(nid)},
|
{{"{nid}", std::to_string(nidx)},
|
||||||
{"{child}", std::to_string(child)},
|
{"{child}", std::to_string(child)},
|
||||||
{"{color}", is_missing ? param_.yes_color : param_.no_color},
|
{"{color}", is_missing ? param_.yes_color : param_.no_color},
|
||||||
{"{branch}", branch}});
|
{"{branch}", branch}});
|
||||||
@ -644,68 +669,77 @@ class GraphvizGenerator : public TreeGenerator {
|
|||||||
|
|
||||||
// Only indicator is different, so we combine all different node types into this
|
// Only indicator is different, so we combine all different node types into this
|
||||||
// function.
|
// function.
|
||||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
std::string PlainNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
|
||||||
auto split_index = tree[nid].SplitIndex();
|
auto split_index = tree.SplitIndex(nidx);
|
||||||
auto cond = tree[nid].SplitCond();
|
auto cond = tree.SplitCond(nidx);
|
||||||
static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
|
static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
|
||||||
|
|
||||||
bool has_less =
|
bool has_less =
|
||||||
(split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator;
|
(split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator;
|
||||||
std::string result =
|
std::string result =
|
||||||
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nid)},
|
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)},
|
||||||
{"{fname}", GetFeatureName(fmap_, split_index)},
|
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||||
{"{<}", has_less ? "<" : ""},
|
{"{<}", has_less ? "<" : ""},
|
||||||
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
|
{"{cond}", has_less ? ToStr(cond) : ""},
|
||||||
{"{params}", param_.condition_node_params}});
|
{"{params}", param_.condition_node_params}});
|
||||||
|
|
||||||
result += BuildEdge<false>(tree, nid, tree[nid].LeftChild(), true);
|
result += BuildEdge<false>(tree, nidx, tree.LeftChild(nidx), true);
|
||||||
result += BuildEdge<false>(tree, nid, tree[nid].RightChild(), false);
|
result += BuildEdge<false>(tree, nidx, tree.RightChild(nidx), false);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override {
|
std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
|
||||||
static std::string const kLabelTemplate =
|
static std::string const kLabelTemplate =
|
||||||
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
|
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
|
||||||
auto cats = GetSplitCategories(tree, nid);
|
auto cats = GetSplitCategories(tree, nidx);
|
||||||
auto cats_str = PrintCatsAsSet(cats);
|
auto cats_str = PrintCatsAsSet(cats);
|
||||||
auto split_index = tree[nid].SplitIndex();
|
auto split_index = tree.SplitIndex(nidx);
|
||||||
|
|
||||||
std::string result =
|
std::string result =
|
||||||
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nid)},
|
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nidx)},
|
||||||
{"{fname}", GetFeatureName(fmap_, split_index)},
|
{"{fname}", GetFeatureName(fmap_, split_index)},
|
||||||
{"{cond}", cats_str},
|
{"{cond}", cats_str},
|
||||||
{"{params}", param_.condition_node_params}});
|
{"{params}", param_.condition_node_params}});
|
||||||
|
|
||||||
result += BuildEdge<true>(tree, nid, tree[nid].LeftChild(), true);
|
result += BuildEdge<true>(tree, nidx, tree.LeftChild(nidx), true);
|
||||||
result += BuildEdge<true>(tree, nid, tree[nid].RightChild(), false);
|
result += BuildEdge<true>(tree, nidx, tree.RightChild(nidx), false);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
std::string LeafNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
|
||||||
static std::string const kLeafTemplate =
|
static std::string const kLeafTemplate = " {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
|
||||||
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
|
// hardcoded limit to avoid dumping long arrays into dot graph.
|
||||||
auto result = SuperT::Match(kLeafTemplate, {
|
bst_target_t constexpr kLimit{3};
|
||||||
{"{nid}", std::to_string(nid)},
|
if (tree.IsMultiTarget()) {
|
||||||
{"{leaf-value}", ToStr(tree[nid].LeafValue())},
|
auto value = tree.GetMultiTargetTree()->LeafValue(nidx);
|
||||||
|
auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)},
|
||||||
|
{"{leaf-value}", ToStr(value, kLimit)},
|
||||||
{"{params}", param_.leaf_node_params}});
|
{"{params}", param_.leaf_node_params}});
|
||||||
return result;
|
return result;
|
||||||
};
|
} else {
|
||||||
|
auto value = tree[nidx].LeafValue();
|
||||||
|
auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)},
|
||||||
|
{"{leaf-value}", ToStr(value)},
|
||||||
|
{"{params}", param_.leaf_node_params}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string BuildTree(RegTree const& tree, bst_node_t nidx, uint32_t depth) override {
|
||||||
if (tree[nid].IsLeaf()) {
|
if (tree.IsLeaf(nidx)) {
|
||||||
return this->LeafNode(tree, nid, depth);
|
return this->LeafNode(tree, nidx, depth);
|
||||||
}
|
}
|
||||||
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
|
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
|
||||||
auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical
|
auto node = tree.GetSplitTypes()[nidx] == FeatureType::kCategorical
|
||||||
? this->Categorical(tree, nid, depth)
|
? this->Categorical(tree, nidx, depth)
|
||||||
: this->PlainNode(tree, nid, depth);
|
: this->PlainNode(tree, nidx, depth);
|
||||||
auto result = SuperT::Match(
|
auto result = SuperT::Match(
|
||||||
kNodeTemplate,
|
kNodeTemplate,
|
||||||
{{"{parent}", node},
|
{{"{parent}", node},
|
||||||
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
|
{"{left}", this->BuildTree(tree, tree.LeftChild(nidx), depth+1)},
|
||||||
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
|
{"{right}", this->BuildTree(tree, tree.RightChild(nidx), depth+1)}});
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -733,7 +767,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
|
|||||||
constexpr bst_node_t RegTree::kRoot;
|
constexpr bst_node_t RegTree::kRoot;
|
||||||
|
|
||||||
std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const {
|
std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const {
|
||||||
CHECK(!IsMultiTarget());
|
if (this->IsMultiTarget() && format != "dot") {
|
||||||
|
LOG(FATAL) << format << " tree dump " << MTNotImplemented();
|
||||||
|
}
|
||||||
std::unique_ptr<TreeGenerator> builder{TreeGenerator::Create(format, fmap, with_stats)};
|
std::unique_ptr<TreeGenerator> builder{TreeGenerator::Create(format, fmap, with_stats)};
|
||||||
builder->BuildTree(*this);
|
builder->BuildTree(*this);
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023 by XGBoost Contributors
|
* Copyright 2023-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/context.h> // for Context
|
#include <xgboost/context.h> // for Context
|
||||||
@ -7,16 +7,23 @@
|
|||||||
#include <xgboost/tree_model.h> // for RegTree
|
#include <xgboost/tree_model.h> // for RegTree
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
TEST(MultiTargetTree, JsonIO) {
|
namespace {
|
||||||
|
auto MakeTreeForTest() {
|
||||||
bst_target_t n_targets{3};
|
bst_target_t n_targets{3};
|
||||||
bst_feature_t n_features{4};
|
bst_feature_t n_features{4};
|
||||||
RegTree tree{n_targets, n_features};
|
RegTree tree{n_targets, n_features};
|
||||||
ASSERT_TRUE(tree.IsMultiTarget());
|
CHECK(tree.IsMultiTarget());
|
||||||
linalg::Vector<float> base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()};
|
linalg::Vector<float> base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()};
|
||||||
linalg::Vector<float> left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()};
|
linalg::Vector<float> left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()};
|
||||||
linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()};
|
linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()};
|
||||||
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
|
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
|
||||||
left_weight.HostView(), right_weight.HostView());
|
left_weight.HostView(), right_weight.HostView());
|
||||||
|
return tree;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TEST(MultiTargetTree, JsonIO) {
|
||||||
|
auto tree = MakeTreeForTest();
|
||||||
ASSERT_EQ(tree.NumNodes(), 3);
|
ASSERT_EQ(tree.NumNodes(), 3);
|
||||||
ASSERT_EQ(tree.NumTargets(), 3);
|
ASSERT_EQ(tree.NumTargets(), 3);
|
||||||
ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
|
ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
|
||||||
@ -44,4 +51,28 @@ TEST(MultiTargetTree, JsonIO) {
|
|||||||
loaded.SaveModel(&jtree1);
|
loaded.SaveModel(&jtree1);
|
||||||
check_jtree(jtree1, tree);
|
check_jtree(jtree1, tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(MultiTargetTree, DumpDot) {
|
||||||
|
auto tree = MakeTreeForTest();
|
||||||
|
auto n_features = tree.NumFeatures();
|
||||||
|
FeatureMap fmap;
|
||||||
|
for (bst_feature_t f = 0; f < n_features; ++f) {
|
||||||
|
auto name = "feat_" + std::to_string(f);
|
||||||
|
fmap.PushBack(f, name.c_str(), "q");
|
||||||
|
}
|
||||||
|
auto str = tree.DumpModel(fmap, true, "dot");
|
||||||
|
ASSERT_NE(str.find("leaf=[2, 3, 4]"), std::string::npos);
|
||||||
|
ASSERT_NE(str.find("leaf=[3, 4, 5]"), std::string::npos);
|
||||||
|
|
||||||
|
{
|
||||||
|
bst_target_t n_targets{4};
|
||||||
|
bst_feature_t n_features{4};
|
||||||
|
RegTree tree{n_targets, n_features};
|
||||||
|
linalg::Vector<float> weight{{1.0f, 2.0f, 3.0f, 4.0f}, {4ul}, DeviceOrd::CPU()};
|
||||||
|
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, weight.HostView(),
|
||||||
|
weight.HostView(), weight.HostView());
|
||||||
|
auto str = tree.DumpModel(fmap, true, "dot");
|
||||||
|
ASSERT_NE(str.find("leaf=[1, 2, ..., 4]"), std::string::npos);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user