[backport] Fix inference with categorical feature. (#8591) (#8602) (#8638)

* Fix inference with categorical feature. (#8591)

* Fix windows build on buildkite. (#8602)

* workaround.
This commit is contained in:
Jiaming Yuan
2023-01-06 01:17:49 +08:00
committed by GitHub
parent 1a834b2b85
commit 067b704e58
7 changed files with 79 additions and 31 deletions

View File

@@ -48,20 +48,21 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
return cat < 0 || cat >= kMaxCat;
}
/* \brief Whether should it traverse to left branch of a tree.
/**
* \brief Whether should it traverse to left branch of a tree.
*
* For one hot split, go to left if it's NOT the matching category.
* Go to left if it's NOT the matching category, which matches one-hot encoding.
*/
template <bool validate = true>
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat) {
KCatBitField const s_cats(cats);
// FIXME: Size() is not accurate since it represents the size of bit set instead of
// actual number of categories.
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
return dft_left;
if (XGBOOST_EXPECT(InvalidCat(cat), false)) {
return true;
}
auto pos = KCatBitField::ToBitPos(cat);
// If the input category is larger than the size of the bit field, it implies that the
// category is not chosen. Otherwise the bit field would have the category instead of
// being smaller than the category value.
if (pos.int_pos >= cats.size()) {
return true;
}

View File

@@ -144,7 +144,7 @@ class PartitionBuilder {
auto gidx = gidx_calc(ridx);
bool go_left = default_left;
if (gidx > -1) {
go_left = Decision(node_cats, cut_values[gidx], default_left);
go_left = Decision(node_cats, cut_values[gidx]);
}
return go_left;
} else {
@@ -157,7 +157,7 @@ class PartitionBuilder {
bool go_left = default_left;
if (gidx > -1) {
if (is_cat) {
go_left = Decision(node_cats, cut_values[gidx], default_left);
go_left = Decision(node_cats, cut_values[gidx]);
} else {
go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
}

View File

@@ -18,9 +18,7 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs
if (has_categorical && common::IsCat(cats.split_type, nid)) {
auto node_categories =
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
? node.LeftChild()
: node.RightChild();
return common::Decision(node_categories, fvalue) ? node.LeftChild() : node.RightChild();
} else {
return node.LeftChild() + !(fvalue < node.SplitCond());
}

View File

@@ -403,8 +403,7 @@ struct GPUHistMakerDevice {
go_left = data.split_node.DefaultLeft();
} else {
if (data.split_type == FeatureType::kCategorical) {
go_left = common::Decision<false>(data.node_cats.Bits(), cut_value,
data.split_node.DefaultLeft());
go_left = common::Decision(data.node_cats.Bits(), cut_value);
} else {
go_left = cut_value <= data.split_node.SplitCond();
}
@@ -481,7 +480,7 @@ struct GPUHistMakerDevice {
if (common::IsCat(d_feature_types, position)) {
auto node_cats = categories.subspan(categories_segments[position].beg,
categories_segments[position].size);
go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
go_left = common::Decision(node_cats, element);
} else {
go_left = element <= node.SplitCond();
}