Expand categorical node. (#6028)

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2020-08-25 18:53:57 +08:00
committed by GitHub
parent 9a4e8b1d81
commit 20c95be625
12 changed files with 340 additions and 103 deletions

View File

@@ -1,9 +1,10 @@
// Copyright by Contributors
#include <gtest/gtest.h>
#include <xgboost/tree_model.h>
#include "../helpers.h"
#include "dmlc/filesystem.h"
#include "xgboost/json_io.h"
#include "xgboost/tree_model.h"
#include "../../../src/common/bitfield.h"
namespace xgboost {
#if DMLC_IO_NO_ENDIAN_SWAP // skip on big-endian machines
@@ -82,7 +83,7 @@ TEST(Tree, Load) {
tree.Load(fi.get());
EXPECT_EQ(tree.GetDepth(1), 1);
EXPECT_EQ(tree[0].SplitCond(), 0.5f);
EXPECT_EQ(tree[0].SplitIndex(), 5);
EXPECT_EQ(tree[0].SplitIndex(), 5ul);
EXPECT_EQ(tree[1].LeafValue(), 0.1f);
EXPECT_TRUE(tree[1].IsLeaf());
}
@@ -105,6 +106,51 @@ TEST(Tree, AllocateNode) {
ASSERT_TRUE(nodes.at(2).IsLeaf());
}
// Exercises RegTree::ExpandCategorical on a fresh tree in two scenarios:
//   1. an empty category list, checking node/leaf counts, split types and the
//      split condition of node 0;
//   2. a single category (33) encoded in a bit field, checking the stored
//      category words both directly and after a JSON save/load round trip.
TEST(Tree, ExpandCategoricalFeature) {
  {
    // Scenario 1: expand node 0 with no categories supplied.
    RegTree tree;
    tree.ExpandCategorical(0, 0, {}, true, 1.0, 2.0, 3.0, 11.0, 2.0,
                           /*left_sum=*/3.0, /*right_sum=*/4.0);

    // Three nodes are expected after the expansion, two of them leaves.
    ASSERT_EQ(tree.GetNodes().size(), 3ul);
    ASSERT_EQ(tree.GetNumLeaves(), 2);

    // Only the expanded node is marked categorical; the other two entries
    // report the numerical type.
    ASSERT_EQ(tree.GetSplitTypes().size(), 3ul);
    ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
    ASSERT_EQ(tree.GetSplitTypes()[1], FeatureType::kNumerical);
    ASSERT_EQ(tree.GetSplitTypes()[2], FeatureType::kNumerical);

    // Nothing was passed in, so nothing should be stored, and the split
    // condition of node 0 is expected to be NaN.
    ASSERT_EQ(tree.GetSplitCategories().size(), 0ul);
    ASSERT_TRUE(std::isnan(tree[0].SplitCond()));
  }
  {
    // Scenario 2: expand node 0 with category 33 set in a bit field.
    RegTree tree;
    bst_cat_t cat = 33;
    // Backing storage sized to hold bit indices [0, cat].
    std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(cat + 1));
    LBitField32 bitset{split_cats};
    bitset.Set(cat);
    tree.ExpandCategorical(0, 0, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
                           /*left_sum=*/3.0, /*right_sum=*/4.0);

    auto categories = tree.GetSplitCategories();
    // Bound by const reference, matching how the loaded tree is inspected
    // below; avoids copying the segment container.
    auto const& segments = tree.GetSplitCategoriesPtr();
    // The three-iterator std::equal below trusts that the second range is at
    // least as long as the first, so pin the lengths explicitly first.
    ASSERT_EQ(segments[0].size, split_cats.size());
    auto got = categories.subspan(segments[0].beg, segments[0].size);
    ASSERT_TRUE(std::equal(got.cbegin(), got.cend(), split_cats.cbegin()));

    // Round trip through the JSON model representation.
    Json out{Object()};
    tree.SaveModel(&out);

    RegTree loaded_tree;
    loaded_tree.LoadModel(out);

    // One segment per node; the segment of node 0 starts at offset 0 and
    // spans two storage words.
    auto const& cat_ptr = loaded_tree.GetSplitCategoriesPtr();
    ASSERT_EQ(cat_ptr.size(), 3ul);
    ASSERT_EQ(cat_ptr[0].beg, 0ul);
    ASSERT_EQ(cat_ptr[0].size, 2ul);

    // The reloaded category words must match what was originally supplied.
    auto loaded_categories = loaded_tree.GetSplitCategories();
    auto loaded_root = loaded_categories.subspan(cat_ptr[0].beg, cat_ptr[0].size);
    ASSERT_EQ(cat_ptr[0].size, split_cats.size());
    ASSERT_TRUE(std::equal(loaded_root.begin(), loaded_root.end(), split_cats.begin()));
  }
}
namespace {
RegTree ConstructTree() {
RegTree tree;
tree.ExpandNode(
@@ -123,6 +169,7 @@ RegTree ConstructTree() {
/*right_sum=*/0.0f);
return tree;
}
} // anonymous namespace
TEST(Tree, DumpJson) {
auto tree = ConstructTree();
@@ -133,14 +180,14 @@ TEST(Tree, DumpJson) {
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
n_leaves++;
}
ASSERT_EQ(n_leaves, 4);
ASSERT_EQ(n_leaves, 4ul);
size_t n_conditions = 0;
iter = 0;
while ((iter = str.find("split_condition", iter + 1)) != std::string::npos) {
n_conditions++;
}
ASSERT_EQ(n_conditions, 3);
ASSERT_EQ(n_conditions, 3ul);
fmap.PushBack(0, "feat_0", "i");
fmap.PushBack(1, "feat_1", "q");
@@ -156,7 +203,7 @@ TEST(Tree, DumpJson) {
auto j_tree = Json::Load({str.c_str(), str.size()});
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2);
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2ul);
}
TEST(Tree, DumpText) {
@@ -168,14 +215,14 @@ TEST(Tree, DumpText) {
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
n_leaves++;
}
ASSERT_EQ(n_leaves, 4);
ASSERT_EQ(n_leaves, 4ul);
iter = 0;
size_t n_conditions = 0;
while ((iter = str.find("gain", iter + 1)) != std::string::npos) {
n_conditions++;
}
ASSERT_EQ(n_conditions, 3);
ASSERT_EQ(n_conditions, 3ul);
ASSERT_NE(str.find("[f0<0]"), std::string::npos);
ASSERT_NE(str.find("[f1<1]"), std::string::npos);
@@ -204,14 +251,14 @@ TEST(Tree, DumpDot) {
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
n_leaves++;
}
ASSERT_EQ(n_leaves, 4);
ASSERT_EQ(n_leaves, 4ul);
size_t n_edges = 0;
iter = 0;
while ((iter = str.find("->", iter + 1)) != std::string::npos) {
n_edges++;
}
ASSERT_EQ(n_edges, 6);
ASSERT_EQ(n_edges, 6ul);
fmap.PushBack(0, "feat_0", "i");
fmap.PushBack(1, "feat_1", "q");
@@ -238,12 +285,12 @@ TEST(Tree, JsonIO) {
ASSERT_EQ(get<String>(tparam["num_nodes"]), "3");
ASSERT_EQ(get<String>(tparam["size_leaf_vector"]), "0");
ASSERT_EQ(get<Array const>(j_tree["left_children"]).size(), 3);
ASSERT_EQ(get<Array const>(j_tree["right_children"]).size(), 3);
ASSERT_EQ(get<Array const>(j_tree["parents"]).size(), 3);
ASSERT_EQ(get<Array const>(j_tree["split_indices"]).size(), 3);
ASSERT_EQ(get<Array const>(j_tree["split_conditions"]).size(), 3);
ASSERT_EQ(get<Array const>(j_tree["default_left"]).size(), 3);
ASSERT_EQ(get<Array const>(j_tree["left_children"]).size(), 3ul);
ASSERT_EQ(get<Array const>(j_tree["right_children"]).size(), 3ul);
ASSERT_EQ(get<Array const>(j_tree["parents"]).size(), 3ul);
ASSERT_EQ(get<Array const>(j_tree["split_indices"]).size(), 3ul);
ASSERT_EQ(get<Array const>(j_tree["split_conditions"]).size(), 3ul);
ASSERT_EQ(get<Array const>(j_tree["default_left"]).size(), 3ul);
RegTree loaded_tree;
loaded_tree.LoadModel(j_tree);
@@ -268,5 +315,4 @@ TEST(Tree, JsonIO) {
ASSERT_EQ(loaded_tree[1].RightChild(), -1);
ASSERT_TRUE(tree.Equal(loaded_tree));
}
} // namespace xgboost