Add high level tests for categorical data. (#6179)

* Fix unique.
This commit is contained in:
Jiaming Yuan
2020-10-09 09:27:23 +08:00
committed by GitHub
parent 6bc9747df5
commit 70ce5216b5
4 changed files with 78 additions and 21 deletions

View File

@@ -801,7 +801,8 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
? 0
: common::KCatBitField::ComputeStorageSize(max_cat);
std::vector<uint32_t> cat_bits_storage(size);
size = size == 0 ? 1 : size;
std::vector<uint32_t> cat_bits_storage(size, 0);
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
for (auto j = j_begin; j < j_end; ++j) {
cat_bits.Set(common::AsCat(get<Integer const>(categories[j])));
@@ -818,7 +819,7 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
if (cnt == categories_nodes.size()) {
last_cat_node = -1;
} else {
last_cat_node = get<Integer const>(categories_nodes[++cnt]);
last_cat_node = get<Integer const>(categories_nodes[cnt]);
}
} else {
split_categories_segments_[nidx].beg = categories.size();