Support categorical split in tree model dump. (#7036)

This commit is contained in:
Jiaming Yuan
2021-06-18 16:46:20 +08:00
committed by GitHub
parent 7968c0d051
commit 29f8fd6fee
8 changed files with 263 additions and 46 deletions

View File

@@ -241,6 +241,65 @@ RegTree ConstructTree() {
/*right_sum=*/0.0f);
return tree;
}
RegTree ConstructTreeCat(std::vector<bst_cat_t>* cond) {
RegTree tree;
std::vector<uint32_t> cats_storage(common::CatBitField::ComputeStorageSize(33), 0);
common::CatBitField split_cats(cats_storage);
split_cats.Set(0);
split_cats.Set(14);
split_cats.Set(32);
cond->push_back(0);
cond->push_back(14);
cond->push_back(32);
tree.ExpandCategorical(0, /*split_index=*/0, cats_storage, true, 0.0f, 2.0,
3.00, 11.0, 2.0, 3.0, 4.0);
auto left = tree[0].LeftChild();
auto right = tree[0].RightChild();
tree.ExpandNode(
/*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
/*right_sum=*/0.0f);
tree.ExpandCategorical(right, /*split_index=*/0, cats_storage, true, 0.0f,
2.0, 3.00, 11.0, 2.0, 3.0, 4.0);
return tree;
}
void TestCategoricalTreeDump(std::string format, std::string sep) {
std::vector<bst_cat_t> cond;
auto tree = ConstructTreeCat(&cond);
FeatureMap fmap;
auto str = tree.DumpModel(fmap, true, format);
std::string cond_str;
for (size_t c = 0; c < cond.size(); ++c) {
cond_str += std::to_string(cond[c]);
if (c != cond.size() - 1) {
cond_str += sep;
}
}
auto pos = str.find(cond_str);
ASSERT_NE(pos, std::string::npos);
pos = str.find(cond_str, pos + 1);
ASSERT_NE(pos, std::string::npos);
fmap.PushBack(0, "feat_0", "categorical");
fmap.PushBack(1, "feat_1", "q");
fmap.PushBack(2, "feat_2", "int");
str = tree.DumpModel(fmap, true, format);
pos = str.find(cond_str);
ASSERT_NE(pos, std::string::npos);
pos = str.find(cond_str, pos + 1);
ASSERT_NE(pos, std::string::npos);
if (format == "json") {
// Make sure it's valid JSON
Json::Load(StringView{str});
}
}
} // anonymous namespace
TEST(Tree, DumpJson) {
@@ -278,6 +337,10 @@ TEST(Tree, DumpJson) {
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2ul);
}
TEST(Tree, DumpJsonCategorical) {
TestCategoricalTreeDump("json", ", ");
}
TEST(Tree, DumpText) {
auto tree = ConstructTree();
FeatureMap fmap;
@@ -313,6 +376,10 @@ TEST(Tree, DumpText) {
ASSERT_EQ(str.find("cover"), std::string::npos);
}
TEST(Tree, DumpTextCategorical) {
TestCategoricalTreeDump("text", ",");
}
TEST(Tree, DumpDot) {
auto tree = ConstructTree();
FeatureMap fmap;
@@ -350,6 +417,10 @@ TEST(Tree, DumpDot) {
ASSERT_NE(str.find(R"(1 -> 4 [label="no, missing")"), std::string::npos);
}
TEST(Tree, DumpDotCategorical) {
TestCategoricalTreeDump("dot", ",");
}
TEST(Tree, JsonIO) {
RegTree tree;
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,