@@ -801,7 +801,8 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
|
||||
? 0
|
||||
: common::KCatBitField::ComputeStorageSize(max_cat);
|
||||
std::vector<uint32_t> cat_bits_storage(size);
|
||||
size = size == 0 ? 1 : size;
|
||||
std::vector<uint32_t> cat_bits_storage(size, 0);
|
||||
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
|
||||
for (auto j = j_begin; j < j_end; ++j) {
|
||||
cat_bits.Set(common::AsCat(get<Integer const>(categories[j])));
|
||||
@@ -818,7 +819,7 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
if (cnt == categories_nodes.size()) {
|
||||
last_cat_node = -1;
|
||||
} else {
|
||||
last_cat_node = get<Integer const>(categories_nodes[++cnt]);
|
||||
last_cat_node = get<Integer const>(categories_nodes[cnt]);
|
||||
}
|
||||
} else {
|
||||
split_categories_segments_[nidx].beg = categories.size();
|
||||
|
||||
Reference in New Issue
Block a user