Expand categorical node. (#6028)
Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
// Copyright by Contributors
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
#include "../helpers.h"
|
||||
#include "dmlc/filesystem.h"
|
||||
#include "xgboost/json_io.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "../../../src/common/bitfield.h"
|
||||
|
||||
namespace xgboost {
|
||||
#if DMLC_IO_NO_ENDIAN_SWAP // skip on big-endian machines
|
||||
@@ -82,7 +83,7 @@ TEST(Tree, Load) {
|
||||
tree.Load(fi.get());
|
||||
EXPECT_EQ(tree.GetDepth(1), 1);
|
||||
EXPECT_EQ(tree[0].SplitCond(), 0.5f);
|
||||
EXPECT_EQ(tree[0].SplitIndex(), 5);
|
||||
EXPECT_EQ(tree[0].SplitIndex(), 5ul);
|
||||
EXPECT_EQ(tree[1].LeafValue(), 0.1f);
|
||||
EXPECT_TRUE(tree[1].IsLeaf());
|
||||
}
|
||||
@@ -105,6 +106,51 @@ TEST(Tree, AllocateNode) {
|
||||
ASSERT_TRUE(nodes.at(2).IsLeaf());
|
||||
}
|
||||
|
||||
// Exercises RegTree::ExpandCategorical on a fresh tree, first with an empty
// category set and then with a single category set, and verifies that the
// stored categories survive a JSON save/load round trip.
TEST(Tree, ExpandCategoricalFeature) {
  {
    // Case 1: expand the root with an empty category list.
    RegTree tree;
    tree.ExpandCategorical(0, 0, {}, true, 1.0, 2.0, 3.0, 11.0, 2.0,
                           /*left_sum=*/3.0, /*right_sum=*/4.0);
    // The expansion yields three nodes, two of which are leaves.
    ASSERT_EQ(tree.GetNodes().size(), 3ul);
    ASSERT_EQ(tree.GetNumLeaves(), 2);
    // Only the expanded node (index 0) is expected to be tagged categorical.
    ASSERT_EQ(tree.GetSplitTypes().size(), 3ul);
    ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
    ASSERT_EQ(tree.GetSplitTypes()[1], FeatureType::kNumerical);
    ASSERT_EQ(tree.GetSplitTypes()[2], FeatureType::kNumerical);
    // No categories were supplied, so none should be stored.
    ASSERT_EQ(tree.GetSplitCategories().size(), 0ul);
    // The categorical node is expected to carry NaN as its split condition.
    ASSERT_TRUE(std::isnan(tree[0].SplitCond()));
  }
  {
    // Case 2: expand the root with a bit field in which category 33 is set.
    RegTree tree;
    bst_cat_t cat = 33;
    // Storage large enough to address bits [0, cat].
    std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(cat+1));
    LBitField32 bitset {split_cats};
    bitset.Set(cat);
    tree.ExpandCategorical(0, 0, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
                           /*left_sum=*/3.0, /*right_sum=*/4.0);

    auto categories = tree.GetSplitCategories();
    auto const& segments = tree.GetSplitCategoriesPtr();
    auto got = categories.subspan(segments[0].beg, segments[0].size);
    // The three-iterator std::equal does not compare lengths, so check the
    // size explicitly; otherwise an empty or truncated segment would pass.
    ASSERT_EQ(got.size(), split_cats.size());
    ASSERT_TRUE(std::equal(got.cbegin(), got.cend(), split_cats.cbegin()));

    // Round trip through JSON and confirm the root's categories are intact.
    Json out{Object()};
    tree.SaveModel(&out);

    RegTree loaded_tree;
    loaded_tree.LoadModel(out);

    auto const& cat_ptr = loaded_tree.GetSplitCategoriesPtr();
    // One segment per node; the root's segment covers the stored words.
    ASSERT_EQ(cat_ptr.size(), 3ul);
    ASSERT_EQ(cat_ptr[0].beg, 0ul);
    ASSERT_EQ(cat_ptr[0].size, 2ul);

    auto loaded_categories = loaded_tree.GetSplitCategories();
    auto loaded_root = loaded_categories.subspan(cat_ptr[0].beg, cat_ptr[0].size);
    ASSERT_TRUE(std::equal(loaded_root.begin(), loaded_root.end(), split_cats.begin()));
  }
}
|
||||
|
||||
namespace {
|
||||
RegTree ConstructTree() {
|
||||
RegTree tree;
|
||||
tree.ExpandNode(
|
||||
@@ -123,6 +169,7 @@ RegTree ConstructTree() {
|
||||
/*right_sum=*/0.0f);
|
||||
return tree;
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Tree, DumpJson) {
|
||||
auto tree = ConstructTree();
|
||||
@@ -133,14 +180,14 @@ TEST(Tree, DumpJson) {
|
||||
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
||||
n_leaves++;
|
||||
}
|
||||
ASSERT_EQ(n_leaves, 4);
|
||||
ASSERT_EQ(n_leaves, 4ul);
|
||||
|
||||
size_t n_conditions = 0;
|
||||
iter = 0;
|
||||
while ((iter = str.find("split_condition", iter + 1)) != std::string::npos) {
|
||||
n_conditions++;
|
||||
}
|
||||
ASSERT_EQ(n_conditions, 3);
|
||||
ASSERT_EQ(n_conditions, 3ul);
|
||||
|
||||
fmap.PushBack(0, "feat_0", "i");
|
||||
fmap.PushBack(1, "feat_1", "q");
|
||||
@@ -156,7 +203,7 @@ TEST(Tree, DumpJson) {
|
||||
|
||||
|
||||
auto j_tree = Json::Load({str.c_str(), str.size()});
|
||||
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2);
|
||||
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2ul);
|
||||
}
|
||||
|
||||
TEST(Tree, DumpText) {
|
||||
@@ -168,14 +215,14 @@ TEST(Tree, DumpText) {
|
||||
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
||||
n_leaves++;
|
||||
}
|
||||
ASSERT_EQ(n_leaves, 4);
|
||||
ASSERT_EQ(n_leaves, 4ul);
|
||||
|
||||
iter = 0;
|
||||
size_t n_conditions = 0;
|
||||
while ((iter = str.find("gain", iter + 1)) != std::string::npos) {
|
||||
n_conditions++;
|
||||
}
|
||||
ASSERT_EQ(n_conditions, 3);
|
||||
ASSERT_EQ(n_conditions, 3ul);
|
||||
|
||||
ASSERT_NE(str.find("[f0<0]"), std::string::npos);
|
||||
ASSERT_NE(str.find("[f1<1]"), std::string::npos);
|
||||
@@ -204,14 +251,14 @@ TEST(Tree, DumpDot) {
|
||||
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
||||
n_leaves++;
|
||||
}
|
||||
ASSERT_EQ(n_leaves, 4);
|
||||
ASSERT_EQ(n_leaves, 4ul);
|
||||
|
||||
size_t n_edges = 0;
|
||||
iter = 0;
|
||||
while ((iter = str.find("->", iter + 1)) != std::string::npos) {
|
||||
n_edges++;
|
||||
}
|
||||
ASSERT_EQ(n_edges, 6);
|
||||
ASSERT_EQ(n_edges, 6ul);
|
||||
|
||||
fmap.PushBack(0, "feat_0", "i");
|
||||
fmap.PushBack(1, "feat_1", "q");
|
||||
@@ -238,12 +285,12 @@ TEST(Tree, JsonIO) {
|
||||
ASSERT_EQ(get<String>(tparam["num_nodes"]), "3");
|
||||
ASSERT_EQ(get<String>(tparam["size_leaf_vector"]), "0");
|
||||
|
||||
ASSERT_EQ(get<Array const>(j_tree["left_children"]).size(), 3);
|
||||
ASSERT_EQ(get<Array const>(j_tree["right_children"]).size(), 3);
|
||||
ASSERT_EQ(get<Array const>(j_tree["parents"]).size(), 3);
|
||||
ASSERT_EQ(get<Array const>(j_tree["split_indices"]).size(), 3);
|
||||
ASSERT_EQ(get<Array const>(j_tree["split_conditions"]).size(), 3);
|
||||
ASSERT_EQ(get<Array const>(j_tree["default_left"]).size(), 3);
|
||||
ASSERT_EQ(get<Array const>(j_tree["left_children"]).size(), 3ul);
|
||||
ASSERT_EQ(get<Array const>(j_tree["right_children"]).size(), 3ul);
|
||||
ASSERT_EQ(get<Array const>(j_tree["parents"]).size(), 3ul);
|
||||
ASSERT_EQ(get<Array const>(j_tree["split_indices"]).size(), 3ul);
|
||||
ASSERT_EQ(get<Array const>(j_tree["split_conditions"]).size(), 3ul);
|
||||
ASSERT_EQ(get<Array const>(j_tree["default_left"]).size(), 3ul);
|
||||
|
||||
RegTree loaded_tree;
|
||||
loaded_tree.LoadModel(j_tree);
|
||||
@@ -268,5 +315,4 @@ TEST(Tree, JsonIO) {
|
||||
ASSERT_EQ(loaded_tree[1].RightChild(), -1);
|
||||
ASSERT_TRUE(tree.Equal(loaded_tree));
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user