Fix inference with categorical feature. (#8591)
This commit is contained in:
parent
7dc3e95a77
commit
43a647a4dd
@ -138,11 +138,11 @@ Miscellaneous
|
|||||||
|
|
||||||
By default, XGBoost assumes input categories are integers starting from 0 till the number
|
By default, XGBoost assumes input categories are integers starting from 0 till the number
|
||||||
of categories :math:`[0, n\_categories)`. However, user might provide inputs with invalid
|
of categories :math:`[0, n\_categories)`. However, user might provide inputs with invalid
|
||||||
values due to mistakes or missing values. It can be negative value, integer values that
|
values due to mistakes or missing values in training dataset. It can be negative value,
|
||||||
can not be accurately represented by 32-bit floating point, or values that are larger than
|
integer values that can not be accurately represented by 32-bit floating point, or values
|
||||||
actual number of unique categories. During training this is validated but for prediction
|
that are larger than actual number of unique categories. During training this is
|
||||||
it's treated as the same as missing value for performance reasons. Lastly, missing values
|
validated but for prediction it's treated as the same as not-chosen category for
|
||||||
are treated as the same as numerical features (using the learned split direction).
|
performance reasons.
|
||||||
|
|
||||||
|
|
||||||
**********
|
**********
|
||||||
|
|||||||
@ -48,20 +48,21 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
|
|||||||
return cat < 0 || cat >= kMaxCat;
|
return cat < 0 || cat >= kMaxCat;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* \brief Whether should it traverse to left branch of a tree.
|
/**
|
||||||
|
* \brief Whether should it traverse to left branch of a tree.
|
||||||
*
|
*
|
||||||
* For one hot split, go to left if it's NOT the matching category.
|
* Go to left if it's NOT the matching category, which matches one-hot encoding.
|
||||||
*/
|
*/
|
||||||
template <bool validate = true>
|
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat) {
|
||||||
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
|
|
||||||
KCatBitField const s_cats(cats);
|
KCatBitField const s_cats(cats);
|
||||||
// FIXME: Size() is not accurate since it represents the size of bit set instead of
|
if (XGBOOST_EXPECT(InvalidCat(cat), false)) {
|
||||||
// actual number of categories.
|
return true;
|
||||||
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
|
|
||||||
return dft_left;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto pos = KCatBitField::ToBitPos(cat);
|
auto pos = KCatBitField::ToBitPos(cat);
|
||||||
|
// If the input category is larger than the size of the bit field, it implies that the
|
||||||
|
// category is not chosen. Otherwise the bit field would have the category instead of
|
||||||
|
// being smaller than the category value.
|
||||||
if (pos.int_pos >= cats.size()) {
|
if (pos.int_pos >= cats.size()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -144,7 +144,7 @@ class PartitionBuilder {
|
|||||||
auto gidx = gidx_calc(ridx);
|
auto gidx = gidx_calc(ridx);
|
||||||
bool go_left = default_left;
|
bool go_left = default_left;
|
||||||
if (gidx > -1) {
|
if (gidx > -1) {
|
||||||
go_left = Decision(node_cats, cut_values[gidx], default_left);
|
go_left = Decision(node_cats, cut_values[gidx]);
|
||||||
}
|
}
|
||||||
return go_left;
|
return go_left;
|
||||||
} else {
|
} else {
|
||||||
@ -157,7 +157,7 @@ class PartitionBuilder {
|
|||||||
bool go_left = default_left;
|
bool go_left = default_left;
|
||||||
if (gidx > -1) {
|
if (gidx > -1) {
|
||||||
if (is_cat) {
|
if (is_cat) {
|
||||||
go_left = Decision(node_cats, cut_values[gidx], default_left);
|
go_left = Decision(node_cats, cut_values[gidx]);
|
||||||
} else {
|
} else {
|
||||||
go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
|
go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,9 +18,7 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs
|
|||||||
if (has_categorical && common::IsCat(cats.split_type, nid)) {
|
if (has_categorical && common::IsCat(cats.split_type, nid)) {
|
||||||
auto node_categories =
|
auto node_categories =
|
||||||
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
|
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
|
||||||
return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
|
return common::Decision(node_categories, fvalue) ? node.LeftChild() : node.RightChild();
|
||||||
? node.LeftChild()
|
|
||||||
: node.RightChild();
|
|
||||||
} else {
|
} else {
|
||||||
return node.LeftChild() + !(fvalue < node.SplitCond());
|
return node.LeftChild() + !(fvalue < node.SplitCond());
|
||||||
}
|
}
|
||||||
|
|||||||
@ -402,8 +402,7 @@ struct GPUHistMakerDevice {
|
|||||||
go_left = data.split_node.DefaultLeft();
|
go_left = data.split_node.DefaultLeft();
|
||||||
} else {
|
} else {
|
||||||
if (data.split_type == FeatureType::kCategorical) {
|
if (data.split_type == FeatureType::kCategorical) {
|
||||||
go_left = common::Decision<false>(data.node_cats.Bits(), cut_value,
|
go_left = common::Decision(data.node_cats.Bits(), cut_value);
|
||||||
data.split_node.DefaultLeft());
|
|
||||||
} else {
|
} else {
|
||||||
go_left = cut_value <= data.split_node.SplitCond();
|
go_left = cut_value <= data.split_node.SplitCond();
|
||||||
}
|
}
|
||||||
@ -480,7 +479,7 @@ struct GPUHistMakerDevice {
|
|||||||
if (common::IsCat(d_feature_types, position)) {
|
if (common::IsCat(d_feature_types, position)) {
|
||||||
auto node_cats = categories.subspan(categories_segments[position].beg,
|
auto node_cats = categories.subspan(categories_segments[position].beg,
|
||||||
categories_segments[position].size);
|
categories_segments[position].size);
|
||||||
go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
|
go_left = common::Decision(node_cats, element);
|
||||||
} else {
|
} else {
|
||||||
go_left = element <= node.SplitCond();
|
go_left = element <= node.SplitCond();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,11 +1,14 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2021 by XGBoost Contributors
|
* Copyright 2021-2022 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/json.h>
|
||||||
|
#include <xgboost/learner.h>
|
||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
#include "../../../src/common/categorical.h"
|
#include "../../../src/common/categorical.h"
|
||||||
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
@ -15,29 +18,75 @@ TEST(Categorical, Decision) {
|
|||||||
|
|
||||||
ASSERT_TRUE(common::InvalidCat(a));
|
ASSERT_TRUE(common::InvalidCat(a));
|
||||||
std::vector<uint32_t> cats(256, 0);
|
std::vector<uint32_t> cats(256, 0);
|
||||||
ASSERT_TRUE(Decision(cats, a, true));
|
ASSERT_TRUE(Decision(cats, a));
|
||||||
|
|
||||||
// larger than size
|
// larger than size
|
||||||
a = 256;
|
a = 256;
|
||||||
ASSERT_TRUE(Decision(cats, a, true));
|
ASSERT_TRUE(Decision(cats, a));
|
||||||
|
|
||||||
// negative
|
// negative
|
||||||
a = -1;
|
a = -1;
|
||||||
ASSERT_TRUE(Decision(cats, a, true));
|
ASSERT_TRUE(Decision(cats, a));
|
||||||
|
|
||||||
CatBitField bits{cats};
|
CatBitField bits{cats};
|
||||||
bits.Set(0);
|
bits.Set(0);
|
||||||
a = -0.5;
|
a = -0.5;
|
||||||
ASSERT_TRUE(Decision(cats, a, true));
|
ASSERT_TRUE(Decision(cats, a));
|
||||||
|
|
||||||
// round toward 0
|
// round toward 0
|
||||||
a = 0.5;
|
a = 0.5;
|
||||||
ASSERT_FALSE(Decision(cats, a, true));
|
ASSERT_FALSE(Decision(cats, a));
|
||||||
|
|
||||||
// valid
|
// valid
|
||||||
a = 13;
|
a = 13;
|
||||||
bits.Set(a);
|
bits.Set(a);
|
||||||
ASSERT_FALSE(Decision(bits.Bits(), a, true));
|
ASSERT_FALSE(Decision(bits.Bits(), a));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Test for running inference with input category greater than the one stored in tree.
|
||||||
|
*/
|
||||||
|
TEST(Categorical, MinimalSet) {
|
||||||
|
std::size_t constexpr kRows = 256, kCols = 1, kCat = 3;
|
||||||
|
std::vector<FeatureType> types{FeatureType::kCategorical};
|
||||||
|
auto Xy =
|
||||||
|
RandomDataGenerator{kRows, kCols, 0.0}.Type(types).MaxCategory(kCat).GenerateDMatrix(true);
|
||||||
|
|
||||||
|
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
|
||||||
|
learner->SetParam("max_depth", "1");
|
||||||
|
learner->SetParam("tree_method", "hist");
|
||||||
|
learner->Configure();
|
||||||
|
learner->UpdateOneIter(0, Xy);
|
||||||
|
|
||||||
|
Json model{Object{}};
|
||||||
|
learner->SaveModel(&model);
|
||||||
|
auto tree = model["learner"]["gradient_booster"]["model"]["trees"][0];
|
||||||
|
ASSERT_GE(get<I32Array const>(tree["categories"]).size(), 1);
|
||||||
|
auto v = get<I32Array const>(tree["categories"])[0];
|
||||||
|
|
||||||
|
HostDeviceVector<float> predt;
|
||||||
|
{
|
||||||
|
std::vector<float> data{kCat, kCat + 1, 32, 33, 34};
|
||||||
|
auto test = GetDMatrixFromData(data, data.size(), kCols);
|
||||||
|
learner->Predict(test, false, &predt, 0, 0, false, /*pred_leaf=*/true);
|
||||||
|
ASSERT_EQ(predt.Size(), data.size());
|
||||||
|
auto const& h_predt = predt.ConstHostSpan();
|
||||||
|
for (auto v : h_predt) {
|
||||||
|
ASSERT_EQ(v, 1); // left child of root node
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
|
||||||
|
learner->LoadModel(model);
|
||||||
|
std::vector<float> data = {static_cast<float>(v)};
|
||||||
|
auto test = GetDMatrixFromData(data, data.size(), kCols);
|
||||||
|
learner->Predict(test, false, &predt, 0, 0, false, /*pred_leaf=*/true);
|
||||||
|
auto const& h_predt = predt.ConstHostSpan();
|
||||||
|
for (auto v : h_predt) {
|
||||||
|
ASSERT_EQ(v, 2); // right child of root node
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user