Fix inference with categorical feature. (#8591)

This commit is contained in:
Jiaming Yuan
2022-12-15 17:57:26 +08:00
committed by GitHub
parent 7dc3e95a77
commit 43a647a4dd
6 changed files with 75 additions and 28 deletions

View File

@@ -48,20 +48,21 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
return cat < 0 || cat >= kMaxCat;
}
/* \brief Whether should it traverse to left branch of a tree.
/**
* \brief Whether should it traverse to left branch of a tree.
*
* For one hot split, go to left if it's NOT the matching category.
* Go to left if it's NOT the matching category, which matches one-hot encoding.
*/
template <bool validate = true>
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat) {
KCatBitField const s_cats(cats);
// FIXME: Size() is not accurate since it represents the size of bit set instead of
// actual number of categories.
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
return dft_left;
if (XGBOOST_EXPECT(InvalidCat(cat), false)) {
return true;
}
auto pos = KCatBitField::ToBitPos(cat);
// If the input category is larger than the size of the bit field, it implies that the
// category is not chosen. Otherwise the bit field would have the category instead of
// being smaller than the category value.
if (pos.int_pos >= cats.size()) {
return true;
}

View File

@@ -144,7 +144,7 @@ class PartitionBuilder {
auto gidx = gidx_calc(ridx);
bool go_left = default_left;
if (gidx > -1) {
go_left = Decision(node_cats, cut_values[gidx], default_left);
go_left = Decision(node_cats, cut_values[gidx]);
}
return go_left;
} else {
@@ -157,7 +157,7 @@ class PartitionBuilder {
bool go_left = default_left;
if (gidx > -1) {
if (is_cat) {
go_left = Decision(node_cats, cut_values[gidx], default_left);
go_left = Decision(node_cats, cut_values[gidx]);
} else {
go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
}