Clarify the behavior of invalid categorical value handling. (#7529)

This commit is contained in:
Jiaming Yuan
2022-01-13 16:11:52 +08:00
committed by GitHub
parent 20c0d60ac7
commit e5e47c3c99
7 changed files with 88 additions and 25 deletions

View File

@@ -95,7 +95,7 @@ class ApproxRowPartitioner {
auto node_cats = categories.subspan(segment.beg, segment.size);
bool go_left = true;
if (is_cat) {
-go_left = common::Decision(node_cats, common::AsCat(cut_value));
+go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft());
} else {
go_left = cut_value <= candidate.split.split_value;
}

View File

@@ -396,7 +396,7 @@ struct GPUHistMakerDevice {
} else {
bool go_left = true;
if (split_type == FeatureType::kCategorical) {
-go_left = common::Decision(node_cats, common::AsCat(cut_value));
+go_left = common::Decision<false>(node_cats, cut_value, split_node.DefaultLeft());
} else {
go_left = cut_value <= split_node.SplitCond();
}
@@ -474,7 +474,7 @@ struct GPUHistMakerDevice {
auto node_cats =
categories.subspan(categories_segments[position].beg,
categories_segments[position].size);
-go_left = common::Decision(node_cats, common::AsCat(element));
+go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
} else {
go_left = element <= node.SplitCond();
}
@@ -573,7 +573,7 @@ struct GPUHistMakerDevice {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large.";
auto cat = common::AsCat(candidate.split.fvalue);
-if (cat < 0) {
+if (common::InvalidCat(cat)) {
common::InvalidCategory();
}
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);