Fix inference with categorical feature. (#8591)

Jiaming Yuan 2022-12-15 17:57:26 +08:00 committed by GitHub
parent 7dc3e95a77
commit 43a647a4dd
6 changed files with 75 additions and 28 deletions

doc/tutorials/categorical.rst

@@ -138,11 +138,11 @@ Miscellaneous
 By default, XGBoost assumes input categories are integers starting from 0 till the number
 of categories :math:`[0, n\_categories)`. However, user might provide inputs with invalid
-values due to mistakes or missing values. It can be negative value, integer values that
-can not be accurately represented by 32-bit floating point, or values that are larger than
-actual number of unique categories. During training this is validated but for prediction
-it's treated as the same as missing value for performance reasons. Lastly, missing values
-are treated as the same as numerical features (using the learned split direction).
+values due to mistakes or missing values in training dataset. It can be negative value,
+integer values that can not be accurately represented by 32-bit floating point, or values
+that are larger than actual number of unique categories. During training this is
+validated but for prediction it's treated as the same as not-chosen category for
+performance reasons.
 
 **********
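
To make the documented behavior concrete: with a one-hot style categorical split, the
single chosen category goes to the right child and everything else, including invalid
or unseen values, goes to the left. Below is a minimal standalone sketch of that rule,
assuming std::bitset as a stand-in for the library's KCatBitField and an illustrative
kMaxCat capacity; it is not XGBoost's implementation.

#include <bitset>
#include <cmath>
#include <cstddef>
#include <iostream>

// Simplified mirror of the decision rule: the node stores its chosen
// categories in a bit field; a sample goes left unless its category bit is
// set, so invalid or unseen categories take the left ("not-chosen") branch.
constexpr std::size_t kMaxCat = 256;  // toy capacity, not the library constant

bool DecisionSketch(std::bitset<kMaxCat> const& chosen, float cat) {
  if (std::isnan(cat) || cat < 0 || cat >= kMaxCat) {
    return true;  // invalid input: treated as not-chosen, go left
  }
  auto pos = static_cast<std::size_t>(cat);  // truncates toward zero
  return !chosen.test(pos);                  // the chosen category goes right
}

int main() {
  std::bitset<kMaxCat> chosen;
  chosen.set(13);
  std::cout << DecisionSketch(chosen, 13.f) << '\n';   // 0: chosen, go right
  std::cout << DecisionSketch(chosen, 200.f) << '\n';  // 1: unseen, go left
  std::cout << DecisionSketch(chosen, -1.f) << '\n';   // 1: invalid, go left
}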

src/common/categorical.h

@@ -48,20 +48,21 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
   return cat < 0 || cat >= kMaxCat;
 }
 
-/* \brief Whether should it traverse to left branch of a tree.
+/**
+ * \brief Whether should it traverse to left branch of a tree.
  *
- * For one hot split, go to left if it's NOT the matching category.
+ * Go to left if it's NOT the matching category, which matches one-hot encoding.
  */
-template <bool validate = true>
-inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
+inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat) {
   KCatBitField const s_cats(cats);
-  // FIXME: Size() is not accurate since it represents the size of bit set instead of
-  // actual number of categories.
-  if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
-    return dft_left;
+  if (XGBOOST_EXPECT(InvalidCat(cat), false)) {
+    return true;
   }
 
   auto pos = KCatBitField::ToBitPos(cat);
+  // If the input category is larger than the size of the bit field, it implies that the
+  // category is not chosen. Otherwise the bit field would have the category instead of
+  // being smaller than the category value.
   if (pos.int_pos >= cats.size()) {
     return true;
   }
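
The comment added above is the heart of the fix on the validation side: the removed
check compared the category against s_cats.Size(), which is the capacity of the bit
set rather than the number of categories, and the removed dft_left fallback sent
invalid categories in the default direction instead of the not-chosen direction. The
word-level bounds check on pos.int_pos is sufficient on its own. A sketch of the
assumed word/bit arithmetic (illustrative names; the real mapping lives in
KCatBitField::ToBitPos):

#include <cstdint>
#include <iostream>

// A category value maps to a word index (int_pos) and a bit offset within
// that 32-bit word, matching the uint32_t span used by the bit field.
struct BitPos {
  std::uint32_t int_pos;  // index of the word holding the bit
  std::uint32_t bit_pos;  // offset of the bit inside that word
};

BitPos ToBitPosSketch(float cat) {
  auto v = static_cast<std::uint32_t>(cat);  // truncates toward zero
  return {v / 32, v % 32};
}

int main() {
  // With a field of two 32-bit words, categories up to 63 are representable.
  // Category 70 maps to word 2, one past the end, so it cannot have been
  // chosen during training and the traversal goes left.
  auto pos = ToBitPosSketch(70.0f);
  std::cout << pos.int_pos << " " << pos.bit_pos << "\n";  // prints: 2 6
}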

src/common/partition_builder.h

@@ -144,7 +144,7 @@ class PartitionBuilder {
         auto gidx = gidx_calc(ridx);
         bool go_left = default_left;
         if (gidx > -1) {
-          go_left = Decision(node_cats, cut_values[gidx], default_left);
+          go_left = Decision(node_cats, cut_values[gidx]);
         }
         return go_left;
       } else {
@@ -157,7 +157,7 @@ class PartitionBuilder {
         bool go_left = default_left;
         if (gidx > -1) {
           if (is_cat) {
-            go_left = Decision(node_cats, cut_values[gidx], default_left);
+            go_left = Decision(node_cats, cut_values[gidx]);
           } else {
             go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
           }
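
Both hunks drop the default_left argument because Decision() no longer falls back to
the default direction; the default now applies only to genuinely missing values
(gidx == -1). A condensed sketch of the resulting routing logic, with simplified
types standing in for XGBoost's (decision represents common::Decision with the
node's categories already bound; all names here are illustrative):

#include <cassert>
#include <optional>

// Missing values follow the learned default direction, categorical values use
// the bit-field membership test, and numerical values use the threshold.
template <typename Decision>
bool GoLeftSketch(std::optional<float> fvalue, bool is_cat, bool default_left,
                  float split_cond, Decision&& decision) {
  if (!fvalue.has_value()) {
    return default_left;  // missing value: use the default direction
  }
  if (is_cat) {
    return decision(*fvalue);  // categorical: membership in the chosen set
  }
  return *fvalue <= split_cond;  // numerical: ordinary comparison
}

int main() {
  auto decision = [](float cat) { return cat != 1.0f; };  // only category 1 is chosen
  assert(GoLeftSketch(std::nullopt, true, true, 0.f, decision));  // missing: default
  assert(!GoLeftSketch(1.0f, true, false, 0.f, decision));        // chosen: go right
  assert(GoLeftSketch(0.5f, false, false, 1.0f, decision));       // 0.5 <= 1.0: left
  return 0;
}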

src/predictor/predict_fn.h

@@ -18,9 +18,7 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs
   if (has_categorical && common::IsCat(cats.split_type, nid)) {
     auto node_categories =
         cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
-    return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
-               ? node.LeftChild()
-               : node.RightChild();
+    return common::Decision(node_categories, fvalue) ? node.LeftChild() : node.RightChild();
   } else {
     return node.LeftChild() + !(fvalue < node.SplitCond());
   }
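
The numerical fallback in the last context line picks the next node without a branch:
it relies on the right child immediately following the left child in the node array,
an assumption encoded by the original expression itself. A minimal illustration:

#include <cassert>

// Adding the boolean !(fvalue < split_cond) to the left-child index selects
// the right child when the comparison fails, since right = left + 1.
int NextNodeSketch(int left_child, float fvalue, float split_cond) {
  return left_child + !(fvalue < split_cond);
}

int main() {
  assert(NextNodeSketch(1, 0.5f, 1.0f) == 1);  // fvalue < split: left child
  assert(NextNodeSketch(1, 2.0f, 1.0f) == 2);  // otherwise: right = left + 1
  return 0;
}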

src/tree/updater_gpu_hist.cu

@@ -402,8 +402,7 @@ struct GPUHistMakerDevice {
       go_left = data.split_node.DefaultLeft();
     } else {
       if (data.split_type == FeatureType::kCategorical) {
-        go_left = common::Decision<false>(data.node_cats.Bits(), cut_value,
-                                          data.split_node.DefaultLeft());
+        go_left = common::Decision(data.node_cats.Bits(), cut_value);
       } else {
         go_left = cut_value <= data.split_node.SplitCond();
       }
@@ -480,7 +479,7 @@ struct GPUHistMakerDevice {
       if (common::IsCat(d_feature_types, position)) {
         auto node_cats = categories.subspan(categories_segments[position].beg,
                                             categories_segments[position].size);
-        go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
+        go_left = common::Decision(node_cats, element);
       } else {
         go_left = element <= node.SplitCond();
       }
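
Like the predictor above, both GPU hunks slice one node's categories out of a single
flat array with a {beg, size} segment before calling Decision(). A toy illustration of
that storage layout (the struct and values here are illustrative, not the library's):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Every node's bit-field words live in one array; a {beg, size} segment
// selects one node's slice, mirroring categories.subspan(beg, size).
struct Segment {
  std::size_t beg;
  std::size_t size;
};

int main() {
  std::vector<std::uint32_t> categories{0x2u, 0x80u, 0x0u, 0x1u};
  std::vector<Segment> segments{{0, 1}, {1, 1}, {2, 2}};  // one per categorical node

  for (std::size_t nid = 0; nid < segments.size(); ++nid) {
    auto const& seg = segments[nid];
    std::cout << "node " << nid << ": " << seg.size << " word(s) at offset "
              << seg.beg << "\n";
  }
  return 0;
}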

tests/cpp/common/test_categorical.cc

@@ -1,11 +1,14 @@
 /*!
- * Copyright 2021 by XGBoost Contributors
+ * Copyright 2021-2022 by XGBoost Contributors
  */
 #include <gtest/gtest.h>
+#include <xgboost/json.h>
+#include <xgboost/learner.h>
 
 #include <limits>
 
 #include "../../../src/common/categorical.h"
+#include "../helpers.h"
 
 namespace xgboost {
 namespace common {
@@ -15,29 +18,75 @@ TEST(Categorical, Decision) {
   ASSERT_TRUE(common::InvalidCat(a));
   std::vector<uint32_t> cats(256, 0);
-  ASSERT_TRUE(Decision(cats, a, true));
+  ASSERT_TRUE(Decision(cats, a));
 
   // larger than size
   a = 256;
-  ASSERT_TRUE(Decision(cats, a, true));
+  ASSERT_TRUE(Decision(cats, a));
 
   // negative
   a = -1;
-  ASSERT_TRUE(Decision(cats, a, true));
+  ASSERT_TRUE(Decision(cats, a));
 
   CatBitField bits{cats};
   bits.Set(0);
   a = -0.5;
-  ASSERT_TRUE(Decision(cats, a, true));
+  ASSERT_TRUE(Decision(cats, a));
 
   // round toward 0
   a = 0.5;
-  ASSERT_FALSE(Decision(cats, a, true));
+  ASSERT_FALSE(Decision(cats, a));
 
   // valid
   a = 13;
   bits.Set(a);
-  ASSERT_FALSE(Decision(bits.Bits(), a, true));
+  ASSERT_FALSE(Decision(bits.Bits(), a));
+}
+
+/**
+ * Test for running inference with input category greater than the one stored in tree.
+ */
+TEST(Categorical, MinimalSet) {
+  std::size_t constexpr kRows = 256, kCols = 1, kCat = 3;
+  std::vector<FeatureType> types{FeatureType::kCategorical};
+  auto Xy =
+      RandomDataGenerator{kRows, kCols, 0.0}.Type(types).MaxCategory(kCat).GenerateDMatrix(true);
+  std::unique_ptr<Learner> learner{Learner::Create({Xy})};
+  learner->SetParam("max_depth", "1");
+  learner->SetParam("tree_method", "hist");
+  learner->Configure();
+  learner->UpdateOneIter(0, Xy);
+
+  Json model{Object{}};
+  learner->SaveModel(&model);
+  auto tree = model["learner"]["gradient_booster"]["model"]["trees"][0];
+  ASSERT_GE(get<I32Array const>(tree["categories"]).size(), 1);
+  auto v = get<I32Array const>(tree["categories"])[0];
+
+  HostDeviceVector<float> predt;
+  {
+    std::vector<float> data{kCat, kCat + 1, 32, 33, 34};
+    auto test = GetDMatrixFromData(data, data.size(), kCols);
+    learner->Predict(test, false, &predt, 0, 0, false, /*pred_leaf=*/true);
+    ASSERT_EQ(predt.Size(), data.size());
+    auto const& h_predt = predt.ConstHostSpan();
+    for (auto v : h_predt) {
+      ASSERT_EQ(v, 1);  // left child of root node
+    }
+  }
+  {
+    std::unique_ptr<Learner> learner{Learner::Create({Xy})};
+    learner->LoadModel(model);
+    std::vector<float> data = {static_cast<float>(v)};
+    auto test = GetDMatrixFromData(data, data.size(), kCols);
+    learner->Predict(test, false, &predt, 0, 0, false, /*pred_leaf=*/true);
+    auto const& h_predt = predt.ConstHostSpan();
+    for (auto v : h_predt) {
+      ASSERT_EQ(v, 2);  // right child of root node
+    }
+  }
 }
 }  // namespace common
 }  // namespace xgboost