Support categorical split in tree model dump. (#7036)
This commit is contained in:
@@ -52,11 +52,6 @@ bst_float PredValue(const SparsePage::Inst &inst,
|
||||
if (tree_info[i] == bst_group) {
|
||||
auto const &tree = *trees[i];
|
||||
bool has_categorical = tree.HasCategoricalSplit();
|
||||
|
||||
auto categories = common::Span<uint32_t const>{tree.GetSplitCategories()};
|
||||
auto split_types = tree.GetSplitTypes();
|
||||
auto categories_ptr =
|
||||
common::Span<RegTree::Segment const>{tree.GetSplitCategoriesPtr()};
|
||||
auto cats = tree.GetCategoriesMatrix();
|
||||
bst_node_t nidx = -1;
|
||||
if (has_categorical) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2020 by Contributors
|
||||
* Copyright 2015-2021 by Contributors
|
||||
* \file tree_model.cc
|
||||
* \brief model structure for tree
|
||||
*/
|
||||
@@ -74,6 +74,7 @@ class TreeGenerator {
|
||||
int32_t /*nid*/, uint32_t /*depth*/) const {
|
||||
return "";
|
||||
}
|
||||
virtual std::string Categorical(RegTree const&, int32_t, uint32_t) const = 0;
|
||||
virtual std::string Integer(RegTree const& /*tree*/,
|
||||
int32_t /*nid*/, uint32_t /*depth*/) const {
|
||||
return "";
|
||||
@@ -92,26 +93,51 @@ class TreeGenerator {
|
||||
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
|
||||
auto const split_index = tree[nid].SplitIndex();
|
||||
std::string result;
|
||||
auto is_categorical = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
|
||||
if (split_index < fmap_.Size()) {
|
||||
auto check_categorical = [&]() {
|
||||
CHECK(is_categorical)
|
||||
<< fmap_.Name(split_index)
|
||||
<< " in feature map is numerical but tree node is categorical.";
|
||||
};
|
||||
auto check_numerical = [&]() {
|
||||
auto is_numerical = !is_categorical;
|
||||
CHECK(is_numerical)
|
||||
<< fmap_.Name(split_index)
|
||||
<< " in feature map is categorical but tree node is numerical.";
|
||||
};
|
||||
|
||||
switch (fmap_.TypeOf(split_index)) {
|
||||
case FeatureMap::kIndicator: {
|
||||
result = this->Indicator(tree, nid, depth);
|
||||
break;
|
||||
}
|
||||
case FeatureMap::kInteger: {
|
||||
result = this->Integer(tree, nid, depth);
|
||||
break;
|
||||
}
|
||||
case FeatureMap::kFloat:
|
||||
case FeatureMap::kQuantitive: {
|
||||
result = this->Quantitive(tree, nid, depth);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unknown feature map type.";
|
||||
case FeatureMap::kCategorical: {
|
||||
check_categorical();
|
||||
result = this->Categorical(tree, nid, depth);
|
||||
break;
|
||||
}
|
||||
case FeatureMap::kIndicator: {
|
||||
check_numerical();
|
||||
result = this->Indicator(tree, nid, depth);
|
||||
break;
|
||||
}
|
||||
case FeatureMap::kInteger: {
|
||||
check_numerical();
|
||||
result = this->Integer(tree, nid, depth);
|
||||
break;
|
||||
}
|
||||
case FeatureMap::kFloat:
|
||||
case FeatureMap::kQuantitive: {
|
||||
check_numerical();
|
||||
result = this->Quantitive(tree, nid, depth);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unknown feature map type.";
|
||||
}
|
||||
} else {
|
||||
result = this->PlainNode(tree, nid, depth);
|
||||
if (is_categorical) {
|
||||
result = this->Categorical(tree, nid, depth);
|
||||
} else {
|
||||
result = this->PlainNode(tree, nid, depth);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -179,6 +205,32 @@ TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const&
|
||||
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
|
||||
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)
|
||||
|
||||
std::vector<bst_cat_t> GetSplitCategories(RegTree const &tree, int32_t nidx) {
|
||||
auto const &csr = tree.GetCategoriesMatrix();
|
||||
auto seg = csr.node_ptr[nidx];
|
||||
auto split = common::KCatBitField{csr.categories.subspan(seg.beg, seg.size)};
|
||||
|
||||
std::vector<bst_cat_t> cats;
|
||||
for (size_t i = 0; i < split.Size(); ++i) {
|
||||
if (split.Check(i)) {
|
||||
cats.push_back(static_cast<bst_cat_t>(i));
|
||||
}
|
||||
}
|
||||
return cats;
|
||||
}
|
||||
|
||||
std::string PrintCatsAsSet(std::vector<bst_cat_t> const &cats) {
|
||||
std::stringstream ss;
|
||||
ss << "{";
|
||||
for (size_t i = 0; i < cats.size(); ++i) {
|
||||
ss << cats[i];
|
||||
if (i != cats.size() - 1) {
|
||||
ss << ",";
|
||||
}
|
||||
}
|
||||
ss << "}";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
class TextGenerator : public TreeGenerator {
|
||||
using SuperT = TreeGenerator;
|
||||
@@ -258,6 +310,17 @@ class TextGenerator : public TreeGenerator {
|
||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||
}
|
||||
|
||||
std::string Categorical(RegTree const &tree, int32_t nid,
|
||||
uint32_t depth) const override {
|
||||
auto cats = GetSplitCategories(tree, nid);
|
||||
std::string cats_str = PrintCatsAsSet(cats);
|
||||
static std::string const kNodeTemplate =
|
||||
"{tabs}{nid}:[{fname}:{cond}] yes={right},no={left},missing={missing}";
|
||||
std::string const result =
|
||||
SplitNodeImpl(tree, nid, kNodeTemplate, cats_str, depth);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
|
||||
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
|
||||
std::string const result = SuperT::Match(
|
||||
@@ -343,6 +406,24 @@ class JsonGenerator : public TreeGenerator {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
auto cats = GetSplitCategories(tree, nid);
|
||||
static std::string const kCategoryTemplate =
|
||||
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
||||
R"I("split_condition": {cond}, "yes": {right}, "no": {left}, )I"
|
||||
R"I("missing": {missing})I";
|
||||
std::string cats_ptr = "[";
|
||||
for (size_t i = 0; i < cats.size(); ++i) {
|
||||
cats_ptr += std::to_string(cats[i]);
|
||||
if (i != cats.size() - 1) {
|
||||
cats_ptr += ", ";
|
||||
}
|
||||
}
|
||||
cats_ptr += "]";
|
||||
auto results = SplitNodeImpl(tree, nid, kCategoryTemplate, cats_ptr, depth);
|
||||
return results;
|
||||
}
|
||||
|
||||
std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
|
||||
std::string const &template_str, std::string cond,
|
||||
uint32_t depth) const {
|
||||
@@ -534,6 +615,27 @@ class GraphvizGenerator : public TreeGenerator {
|
||||
}
|
||||
|
||||
protected:
|
||||
template <bool is_categorical>
|
||||
std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const {
|
||||
static std::string const kEdgeTemplate =
|
||||
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
|
||||
// Is this the default child for missing value?
|
||||
bool is_missing = tree[nid].DefaultChild() == child;
|
||||
std::string branch;
|
||||
if (is_categorical) {
|
||||
branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""};
|
||||
} else {
|
||||
branch = std::string{left ? "yes" : "no"} + std::string{is_missing ? ", missing" : ""};
|
||||
}
|
||||
std::string buffer =
|
||||
SuperT::Match(kEdgeTemplate,
|
||||
{{"{nid}", std::to_string(nid)},
|
||||
{"{child}", std::to_string(child)},
|
||||
{"{color}", is_missing ? param_.yes_color : param_.no_color},
|
||||
{"{branch}", branch}});
|
||||
return buffer;
|
||||
}
|
||||
|
||||
// Only indicator is different, so we combine all different node types into this
|
||||
// function.
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||
@@ -552,27 +654,32 @@ class GraphvizGenerator : public TreeGenerator {
|
||||
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
|
||||
{"{params}", param_.condition_node_params}});
|
||||
|
||||
static std::string const kEdgeTemplate =
|
||||
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
|
||||
auto MatchFn = SuperT::Match; // mingw failed to capture protected fn.
|
||||
auto BuildEdge =
|
||||
[&tree, nid, MatchFn, this](int32_t child, bool left) {
|
||||
// Is this the default child for missing value?
|
||||
bool is_missing = tree[nid].DefaultChild() == child;
|
||||
std::string branch = std::string {left ? "yes" : "no"} +
|
||||
std::string {is_missing ? ", missing" : ""};
|
||||
std::string buffer = MatchFn(kEdgeTemplate, {
|
||||
{"{nid}", std::to_string(nid)},
|
||||
{"{child}", std::to_string(child)},
|
||||
{"{color}", is_missing ? param_.yes_color : param_.no_color},
|
||||
{"{branch}", branch}});
|
||||
return buffer;
|
||||
};
|
||||
result += BuildEdge(tree[nid].LeftChild(), true);
|
||||
result += BuildEdge(tree[nid].RightChild(), false);
|
||||
result += BuildEdge<false>(tree, nid, tree[nid].LeftChild(), true);
|
||||
result += BuildEdge<false>(tree, nid, tree[nid].RightChild(), false);
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||
static std::string const kLabelTemplate =
|
||||
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
|
||||
auto cats = GetSplitCategories(tree, nid);
|
||||
auto cats_str = PrintCatsAsSet(cats);
|
||||
auto split = tree[nid].SplitIndex();
|
||||
std::string result = SuperT::Match(
|
||||
kLabelTemplate,
|
||||
{{"{nid}", std::to_string(nid)},
|
||||
{"{fname}", split < fmap_.Size() ? fmap_.Name(split)
|
||||
: 'f' + std::to_string(split)},
|
||||
{"{cond}", cats_str},
|
||||
{"{params}", param_.condition_node_params}});
|
||||
|
||||
result += BuildEdge<true>(tree, nid, tree[nid].LeftChild(), true);
|
||||
result += BuildEdge<true>(tree, nid, tree[nid].RightChild(), false);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
|
||||
static std::string const kLeafTemplate =
|
||||
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
|
||||
@@ -588,9 +695,12 @@ class GraphvizGenerator : public TreeGenerator {
|
||||
return this->LeafNode(tree, nid, depth);
|
||||
}
|
||||
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
|
||||
auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical
|
||||
? this->Categorical(tree, nid, depth)
|
||||
: this->PlainNode(tree, nid, depth);
|
||||
auto result = SuperT::Match(
|
||||
kNodeTemplate,
|
||||
{{"{parent}", this->PlainNode(tree, nid, depth)},
|
||||
{{"{parent}", node},
|
||||
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
|
||||
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
|
||||
return result;
|
||||
|
||||
Reference in New Issue
Block a user