Support categorical split in tree model dump. (#7036)

This commit is contained in:
Jiaming Yuan 2021-06-18 16:46:20 +08:00 committed by GitHub
parent 7968c0d051
commit 29f8fd6fee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 263 additions and 46 deletions

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014 by Contributors * Copyright 2014-2021 by Contributors
* \file feature_map.h * \file feature_map.h
* \brief Feature map data structure to help visualization and model dump. * \brief Feature map data structure to help visualization and model dump.
* \author Tianqi Chen * \author Tianqi Chen
@ -26,7 +26,8 @@ class FeatureMap {
kIndicator = 0, kIndicator = 0,
kQuantitive = 1, kQuantitive = 1,
kInteger = 2, kInteger = 2,
kFloat = 3 kFloat = 3,
kCategorical = 4
}; };
/*! /*!
* \brief load feature map from input stream * \brief load feature map from input stream
@ -82,6 +83,7 @@ class FeatureMap {
if (!strcmp("q", tname)) return kQuantitive; if (!strcmp("q", tname)) return kQuantitive;
if (!strcmp("int", tname)) return kInteger; if (!strcmp("int", tname)) return kInteger;
if (!strcmp("float", tname)) return kFloat; if (!strcmp("float", tname)) return kFloat;
if (!strcmp("categorical", tname)) return kCategorical;
LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity"; LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity";
return kIndicator; return kIndicator;
} }

View File

@ -3,6 +3,7 @@
# coding: utf-8 # coding: utf-8
"""Plotting Library.""" """Plotting Library."""
from io import BytesIO from io import BytesIO
import json
import numpy as np import numpy as np
from .core import Booster from .core import Booster
from .sklearn import XGBModel from .sklearn import XGBModel
@ -203,7 +204,7 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir=None,
if kwargs: if kwargs:
parameters += ':' parameters += ':'
parameters += str(kwargs) parameters += json.dumps(kwargs)
tree = booster.get_dump( tree = booster.get_dump(
fmap=fmap, fmap=fmap,
dump_format=parameters)[num_trees] dump_format=parameters)[num_trees]

View File

@ -52,11 +52,6 @@ bst_float PredValue(const SparsePage::Inst &inst,
if (tree_info[i] == bst_group) { if (tree_info[i] == bst_group) {
auto const &tree = *trees[i]; auto const &tree = *trees[i];
bool has_categorical = tree.HasCategoricalSplit(); bool has_categorical = tree.HasCategoricalSplit();
auto categories = common::Span<uint32_t const>{tree.GetSplitCategories()};
auto split_types = tree.GetSplitTypes();
auto categories_ptr =
common::Span<RegTree::Segment const>{tree.GetSplitCategoriesPtr()};
auto cats = tree.GetCategoriesMatrix(); auto cats = tree.GetCategoriesMatrix();
bst_node_t nidx = -1; bst_node_t nidx = -1;
if (has_categorical) { if (has_categorical) {

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2015-2020 by Contributors * Copyright 2015-2021 by Contributors
* \file tree_model.cc * \file tree_model.cc
* \brief model structure for tree * \brief model structure for tree
*/ */
@ -74,6 +74,7 @@ class TreeGenerator {
int32_t /*nid*/, uint32_t /*depth*/) const { int32_t /*nid*/, uint32_t /*depth*/) const {
return ""; return "";
} }
virtual std::string Categorical(RegTree const&, int32_t, uint32_t) const = 0;
virtual std::string Integer(RegTree const& /*tree*/, virtual std::string Integer(RegTree const& /*tree*/,
int32_t /*nid*/, uint32_t /*depth*/) const { int32_t /*nid*/, uint32_t /*depth*/) const {
return ""; return "";
@ -92,26 +93,51 @@ class TreeGenerator {
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) { virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
auto const split_index = tree[nid].SplitIndex(); auto const split_index = tree[nid].SplitIndex();
std::string result; std::string result;
auto is_categorical = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
if (split_index < fmap_.Size()) { if (split_index < fmap_.Size()) {
auto check_categorical = [&]() {
CHECK(is_categorical)
<< fmap_.Name(split_index)
<< " in feature map is numerical but tree node is categorical.";
};
auto check_numerical = [&]() {
auto is_numerical = !is_categorical;
CHECK(is_numerical)
<< fmap_.Name(split_index)
<< " in feature map is categorical but tree node is numerical.";
};
switch (fmap_.TypeOf(split_index)) { switch (fmap_.TypeOf(split_index)) {
case FeatureMap::kIndicator: { case FeatureMap::kCategorical: {
result = this->Indicator(tree, nid, depth); check_categorical();
break; result = this->Categorical(tree, nid, depth);
} break;
case FeatureMap::kInteger: { }
result = this->Integer(tree, nid, depth); case FeatureMap::kIndicator: {
break; check_numerical();
} result = this->Indicator(tree, nid, depth);
case FeatureMap::kFloat: break;
case FeatureMap::kQuantitive: { }
result = this->Quantitive(tree, nid, depth); case FeatureMap::kInteger: {
break; check_numerical();
} result = this->Integer(tree, nid, depth);
default: break;
LOG(FATAL) << "Unknown feature map type."; }
case FeatureMap::kFloat:
case FeatureMap::kQuantitive: {
check_numerical();
result = this->Quantitive(tree, nid, depth);
break;
}
default:
LOG(FATAL) << "Unknown feature map type.";
} }
} else { } else {
result = this->PlainNode(tree, nid, depth); if (is_categorical) {
result = this->Categorical(tree, nid, depth);
} else {
result = this->PlainNode(tree, nid, depth);
}
} }
return result; return result;
} }
@ -179,6 +205,32 @@ TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const&
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \ __make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name) ::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)
// Collect the category indices stored in the split bitset of node `nidx`.
// Returns the categories in ascending order.
std::vector<bst_cat_t> GetSplitCategories(RegTree const &tree, int32_t nidx) {
  auto const &csr = tree.GetCategoriesMatrix();
  auto const segment = csr.node_ptr[nidx];
  common::KCatBitField const bits{csr.categories.subspan(segment.beg, segment.size)};

  std::vector<bst_cat_t> matched;
  size_t const n_bits = bits.Size();
  for (size_t cat = 0; cat < n_bits; ++cat) {
    if (bits.Check(cat)) {
      matched.emplace_back(static_cast<bst_cat_t>(cat));
    }
  }
  return matched;
}
std::string PrintCatsAsSet(std::vector<bst_cat_t> const &cats) {
std::stringstream ss;
ss << "{";
for (size_t i = 0; i < cats.size(); ++i) {
ss << cats[i];
if (i != cats.size() - 1) {
ss << ",";
}
}
ss << "}";
return ss.str();
}
class TextGenerator : public TreeGenerator { class TextGenerator : public TreeGenerator {
using SuperT = TreeGenerator; using SuperT = TreeGenerator;
@ -258,6 +310,17 @@ class TextGenerator : public TreeGenerator {
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
} }
// Dump a categorical split node in text format.  Note the yes/no pairing is
// inverted relative to numerical splits: matching categories are routed to
// the right child, so "yes" points at {right}.
std::string Categorical(RegTree const &tree, int32_t nid,
                        uint32_t depth) const override {
  static std::string const kNodeTemplate =
      "{tabs}{nid}:[{fname}:{cond}] yes={right},no={left},missing={missing}";
  auto const condition = PrintCatsAsSet(GetSplitCategories(tree, nid));
  return SplitNodeImpl(tree, nid, kNodeTemplate, condition, depth);
}
std::string NodeStat(RegTree const& tree, int32_t nid) const override { std::string NodeStat(RegTree const& tree, int32_t nid) const override {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}"; static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match( std::string const result = SuperT::Match(
@ -343,6 +406,24 @@ class JsonGenerator : public TreeGenerator {
return result; return result;
} }
// Dump a categorical split node in JSON format.  The split condition is a
// JSON list of the matching categories, e.g. [0, 14, 32]; "yes" maps to the
// right child since matching categories are routed right.
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t depth) const override {
  static std::string const kCategoryTemplate =
      R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
      R"I("split_condition": {cond}, "yes": {right}, "no": {left}, )I"
      R"I("missing": {missing})I";
  auto const cats = GetSplitCategories(tree, nid);
  std::string condition = "[";
  char const *sep = "";
  for (auto cat : cats) {
    condition += sep;
    condition += std::to_string(cat);
    sep = ", ";
  }
  condition += "]";
  return SplitNodeImpl(tree, nid, kCategoryTemplate, condition, depth);
}
std::string SplitNodeImpl(RegTree const &tree, int32_t nid, std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
std::string const &template_str, std::string cond, std::string const &template_str, std::string cond,
uint32_t depth) const { uint32_t depth) const {
@ -534,6 +615,27 @@ class GraphvizGenerator : public TreeGenerator {
} }
protected: protected:
template <bool is_categorical>
std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const {
  static std::string const kEdgeTemplate =
      " {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
  // Does this child receive rows whose feature value is missing?
  bool const is_missing = tree[nid].DefaultChild() == child;
  // Categorical splits send matching categories to the right child, so the
  // yes/no labelling is inverted relative to numerical splits.
  std::string branch{is_categorical ? (left ? "no" : "yes") : (left ? "yes" : "no")};
  if (is_missing) {
    branch += ", missing";
  }
  return SuperT::Match(kEdgeTemplate,
                       {{"{nid}", std::to_string(nid)},
                        {"{child}", std::to_string(child)},
                        {"{color}", is_missing ? param_.yes_color : param_.no_color},
                        {"{branch}", branch}});
}
// Only indicator is different, so we combine all different node types into this // Only indicator is different, so we combine all different node types into this
// function. // function.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override { std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
@ -552,27 +654,32 @@ class GraphvizGenerator : public TreeGenerator {
{"{cond}", has_less ? SuperT::ToStr(cond) : ""}, {"{cond}", has_less ? SuperT::ToStr(cond) : ""},
{"{params}", param_.condition_node_params}}); {"{params}", param_.condition_node_params}});
static std::string const kEdgeTemplate = result += BuildEdge<false>(tree, nid, tree[nid].LeftChild(), true);
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n"; result += BuildEdge<false>(tree, nid, tree[nid].RightChild(), false);
auto MatchFn = SuperT::Match; // mingw failed to capture protected fn.
auto BuildEdge =
[&tree, nid, MatchFn, this](int32_t child, bool left) {
// Is this the default child for missing value?
bool is_missing = tree[nid].DefaultChild() == child;
std::string branch = std::string {left ? "yes" : "no"} +
std::string {is_missing ? ", missing" : ""};
std::string buffer = MatchFn(kEdgeTemplate, {
{"{nid}", std::to_string(nid)},
{"{child}", std::to_string(child)},
{"{color}", is_missing ? param_.yes_color : param_.no_color},
{"{branch}", branch}});
return buffer;
};
result += BuildEdge(tree[nid].LeftChild(), true);
result += BuildEdge(tree[nid].RightChild(), false);
return result; return result;
}; };
// Emit the graphviz node for a categorical split together with its two
// outgoing edges.  The label shows the feature name (or fN when no feature
// map entry exists) and the set of matching categories.
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override {
  static std::string const kLabelTemplate =
      " {nid} [ label=\"{fname}:{cond}\" {params}]\n";
  auto const split_index = tree[nid].SplitIndex();
  auto const condition = PrintCatsAsSet(GetSplitCategories(tree, nid));
  std::string result = SuperT::Match(
      kLabelTemplate,
      {{"{nid}", std::to_string(nid)},
       {"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index)
                                              : 'f' + std::to_string(split_index)},
       {"{cond}", condition},
       {"{params}", param_.condition_node_params}});
  result += BuildEdge<true>(tree, nid, tree[nid].LeftChild(), true);
  result += BuildEdge<true>(tree, nid, tree[nid].RightChild(), false);
  return result;
}
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override { std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
static std::string const kLeafTemplate = static std::string const kLeafTemplate =
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; " {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
@ -588,9 +695,12 @@ class GraphvizGenerator : public TreeGenerator {
return this->LeafNode(tree, nid, depth); return this->LeafNode(tree, nid, depth);
} }
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}"; static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical
? this->Categorical(tree, nid, depth)
: this->PlainNode(tree, nid, depth);
auto result = SuperT::Match( auto result = SuperT::Match(
kNodeTemplate, kNodeTemplate,
{{"{parent}", this->PlainNode(tree, nid, depth)}, {{"{parent}", node},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)}, {"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}}); {"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
return result; return result;

View File

@ -241,6 +241,65 @@ RegTree ConstructTree() {
/*right_sum=*/0.0f); /*right_sum=*/0.0f);
return tree; return tree;
} }
// Build a small tree with two categorical split nodes (root and its right
// child) and one numerical split (left child).  The categories used in the
// bitset ({0, 14, 32}) are appended to `cond` so callers can verify they
// appear in the model dump.
RegTree ConstructTreeCat(std::vector<bst_cat_t>* cond) {
  RegTree tree;
  // 33 categories require two 32-bit words of bitfield storage.
  std::vector<uint32_t> cats_storage(common::CatBitField::ComputeStorageSize(33), 0);
  common::CatBitField split_cats(cats_storage);
  for (auto cat : {0, 14, 32}) {
    split_cats.Set(cat);
    cond->push_back(cat);
  }

  tree.ExpandCategorical(0, /*split_index=*/0, cats_storage, true, 0.0f, 2.0,
                         3.00, 11.0, 2.0, 3.0, 4.0);
  auto const left = tree[0].LeftChild();
  auto const right = tree[0].RightChild();
  tree.ExpandNode(
      /*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
      /*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
      /*right_sum=*/0.0f);
  tree.ExpandCategorical(right, /*split_index=*/0, cats_storage, true, 0.0f,
                         2.0, 3.00, 11.0, 2.0, 3.0, 4.0);
  return tree;
}
void TestCategoricalTreeDump(std::string format, std::string sep) {
std::vector<bst_cat_t> cond;
auto tree = ConstructTreeCat(&cond);
FeatureMap fmap;
auto str = tree.DumpModel(fmap, true, format);
std::string cond_str;
for (size_t c = 0; c < cond.size(); ++c) {
cond_str += std::to_string(cond[c]);
if (c != cond.size() - 1) {
cond_str += sep;
}
}
auto pos = str.find(cond_str);
ASSERT_NE(pos, std::string::npos);
pos = str.find(cond_str, pos + 1);
ASSERT_NE(pos, std::string::npos);
fmap.PushBack(0, "feat_0", "categorical");
fmap.PushBack(1, "feat_1", "q");
fmap.PushBack(2, "feat_2", "int");
str = tree.DumpModel(fmap, true, format);
pos = str.find(cond_str);
ASSERT_NE(pos, std::string::npos);
pos = str.find(cond_str, pos + 1);
ASSERT_NE(pos, std::string::npos);
if (format == "json") {
// Make sure it's valid JSON
Json::Load(StringView{str});
}
}
} // anonymous namespace } // anonymous namespace
TEST(Tree, DumpJson) { TEST(Tree, DumpJson) {
@ -278,6 +337,10 @@ TEST(Tree, DumpJson) {
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2ul); ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2ul);
} }
TEST(Tree, DumpJsonCategorical) {
  // JSON dumps render the category list as [0, 14, 32] -- ", " separated.
  TestCategoricalTreeDump("json", ", ");
}
TEST(Tree, DumpText) { TEST(Tree, DumpText) {
auto tree = ConstructTree(); auto tree = ConstructTree();
FeatureMap fmap; FeatureMap fmap;
@ -313,6 +376,10 @@ TEST(Tree, DumpText) {
ASSERT_EQ(str.find("cover"), std::string::npos); ASSERT_EQ(str.find("cover"), std::string::npos);
} }
TEST(Tree, DumpTextCategorical) {
  // Text dumps render the category set as {0,14,32} -- "," separated.
  TestCategoricalTreeDump("text", ",");
}
TEST(Tree, DumpDot) { TEST(Tree, DumpDot) {
auto tree = ConstructTree(); auto tree = ConstructTree();
FeatureMap fmap; FeatureMap fmap;
@ -350,6 +417,10 @@ TEST(Tree, DumpDot) {
ASSERT_NE(str.find(R"(1 -> 4 [label="no, missing")"), std::string::npos); ASSERT_NE(str.find(R"(1 -> 4 [label="no, missing")"), std::string::npos);
} }
TEST(Tree, DumpDotCategorical) {
  // Graphviz node labels render the category set as {0,14,32} -- "," separated.
  TestCategoricalTreeDump("dot", ",");
}
TEST(Tree, JsonIO) { TEST(Tree, JsonIO) {
RegTree tree; RegTree tree;
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,

View File

@ -0,0 +1,40 @@
import sys
import xgboost as xgb
import pytest
import json
sys.path.append("tests/python")
import testing as tm
try:
import matplotlib
matplotlib.use("Agg")
from matplotlib.axes import Axes
from graphviz import Source
except ImportError:
pass
pytestmark = pytest.mark.skipif(**tm.no_multiple(tm.no_matplotlib(), tm.no_graphviz()))
class TestPlotting:
    """Plotting tests for models containing categorical splits."""

    @pytest.mark.skipif(**tm.no_pandas())
    def test_categorical(self) -> None:
        """Train a small categorical model, then verify that every dumped
        node is either a leaf or carries a list-valued split condition, and
        that the last tree can be rendered by graphviz and matplotlib."""
        X, y = tm.make_categorical(1000, 31, 19, onehot=False)
        reg = xgb.XGBRegressor(
            enable_categorical=True, n_estimators=10, tree_method="gpu_hist"
        )
        reg.fit(X, y)
        trees = reg.get_booster().get_dump(dump_format="json")
        for tree in trees:
            j_tree = json.loads(tree)
            # Categorical split conditions are dumped as lists of categories.
            assert "leaf" in j_tree.keys() or isinstance(
                j_tree["split_condition"], list
            )

        # Plot the last tree.  The index must come from the number of trees;
        # the previous `len(j_tree) - 1` counted keys of a single node dict
        # and is not a valid tree index in general.
        last_tree = len(trees) - 1
        graph = xgb.to_graphviz(reg, num_trees=last_tree)
        assert isinstance(graph, Source)
        ax = xgb.plot_tree(reg, num_trees=last_tree)
        assert isinstance(ax, Axes)

View File

@ -71,7 +71,6 @@ class TestGPUUpdaters:
@settings(deadline=None) @settings(deadline=None)
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
def test_categorical(self, rows, cols, rounds, cats): def test_categorical(self, rows, cols, rounds, cats):
pytest.xfail(reason='TestGPUUpdaters::test_categorical is flaky')
self.run_categorical_basic(rows, cols, rounds, cats) self.run_categorical_basic(rows, cols, rounds, cats)
def test_categorical_32_cat(self): def test_categorical_32_cat(self):

View File

@ -55,7 +55,6 @@ def test_categorical():
tree_method="gpu_hist", tree_method="gpu_hist",
use_label_encoder=False, use_label_encoder=False,
enable_categorical=True, enable_categorical=True,
predictor="gpu_predictor",
n_estimators=10, n_estimators=10,
) )
X = pd.DataFrame(X.todense()).astype("category") X = pd.DataFrame(X.todense()).astype("category")