Clarify the behavior of invalid categorical value handling. (#7529)

This commit is contained in:
Jiaming Yuan 2022-01-13 16:11:52 +08:00 committed by GitHub
parent 20c0d60ac7
commit e5e47c3c99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 88 additions and 25 deletions

View File

@ -108,6 +108,18 @@ feature it's specified as ``"c"``. The Dask module in XGBoost has the same inte
:class:`dask.Array <dask.Array>` can also be used as categorical data.
*************
Miscellaneous
*************
By default, XGBoost assumes input categories are integers starting from 0 till the number
of categories :math:`[0, n_categories)`. However, user might provide inputs with invalid
values due to mistakes or missing values. It can be negative value, floating point value
that can not be represented by 32-bit integer, or values that are larger than actual
number of unique categories. During training this is validated but for prediction it's
treated as the same as missing value for performance reasons. Lastly, missing values are
treated as the same as numerical features.
**********
Next Steps
**********

View File

@ -5,6 +5,8 @@
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
#define XGBOOST_COMMON_CATEGORICAL_H_
#include <limits>
#include "bitfield.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
@ -30,22 +32,30 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
}
inline XGBOOST_DEVICE bool InvalidCat(float cat) {
return cat < 0 || cat > static_cast<float>(std::numeric_limits<bst_cat_t>::max());
}
/* \brief Whether should it traverse to left branch of a tree.
*
* For one hot split, go to left if it's NOT the matching category.
*/
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) {
auto pos = CLBitField32::ToBitPos(cat);
if (pos.int_pos >= cats.size()) {
return true;
}
template <bool validate = true>
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
CLBitField32 const s_cats(cats);
return !s_cats.Check(cat);
// FIXME: Size() is not accurate since it represents the size of bit set instead of
// actual number of categories.
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
return dft_left;
}
return !s_cats.Check(AsCat(cat));
}
inline void InvalidCategory() {
LOG(FATAL) << "Invalid categorical value detected. Categorical value "
"should be non-negative.";
"should be non-negative, less than maximum size of int32 and less than total "
"number of categories in training data.";
}
/*!
@ -58,9 +68,7 @@ inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task)
}
struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) {
return ft == FeatureType::kCategorical;
}
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
};
using CatBitField = LBitField32;

View File

@ -581,14 +581,14 @@ void SketchContainer::AllReduce() {
}
namespace {
struct InvalidCat {
struct InvalidCatOp {
Span<float const> values;
Span<uint32_t const> ptrs;
Span<FeatureType const> ft;
XGBOOST_DEVICE bool operator()(size_t i) {
auto fidx = dh::SegmentId(ptrs, i);
return IsCat(ft, fidx) && values[i] < 0;
return IsCat(ft, fidx) && InvalidCat(values[i]);
}
};
} // anonymous namespace
@ -687,10 +687,10 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan();
auto it = thrust::make_counting_iterator(0ul);
CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size());
auto invalid =
thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
InvalidCat{out_cut_values, ptrs, d_ft});
auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
InvalidCatOp{out_cut_values, ptrs, d_ft});
if (invalid) {
InvalidCategory();
}

View File

@ -9,16 +9,16 @@
namespace xgboost {
namespace predictor {
template <bool has_missing, bool has_categorical>
inline XGBOOST_DEVICE bst_node_t
GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue,
bool is_missing, RegTree::CategoricalSplitMatrix const &cats) {
inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
float fvalue, bool is_missing,
RegTree::CategoricalSplitMatrix const &cats) {
if (has_missing && is_missing) {
return node.DefaultChild();
} else {
if (has_categorical && common::IsCat(cats.split_type, nid)) {
auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg,
cats.node_ptr[nid].size);
return Decision(node_categories, common::AsCat(fvalue))
auto node_categories =
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
? node.LeftChild()
: node.RightChild();
} else {

View File

@ -95,7 +95,7 @@ class ApproxRowPartitioner {
auto node_cats = categories.subspan(segment.beg, segment.size);
bool go_left = true;
if (is_cat) {
go_left = common::Decision(node_cats, common::AsCat(cut_value));
go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft());
} else {
go_left = cut_value <= candidate.split.split_value;
}

View File

@ -396,7 +396,7 @@ struct GPUHistMakerDevice {
} else {
bool go_left = true;
if (split_type == FeatureType::kCategorical) {
go_left = common::Decision(node_cats, common::AsCat(cut_value));
go_left = common::Decision<false>(node_cats, cut_value, split_node.DefaultLeft());
} else {
go_left = cut_value <= split_node.SplitCond();
}
@ -474,7 +474,7 @@ struct GPUHistMakerDevice {
auto node_cats =
categories.subspan(categories_segments[position].beg,
categories_segments[position].size);
go_left = common::Decision(node_cats, common::AsCat(element));
go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
} else {
go_left = element <= node.SplitCond();
}
@ -573,7 +573,7 @@ struct GPUHistMakerDevice {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large.";
auto cat = common::AsCat(candidate.split.fvalue);
if (cat < 0) {
if (common::InvalidCat(cat)) {
common::InvalidCategory();
}
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);

View File

@ -0,0 +1,43 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <limits>
#include "../../../src/common/categorical.h"
namespace xgboost {
namespace common {
TEST(Categorical, Decision) {
// inf
float a = std::numeric_limits<float>::infinity();
ASSERT_TRUE(common::InvalidCat(a));
std::vector<uint32_t> cats(256, 0);
ASSERT_TRUE(Decision(cats, a, true));
// larger than size
a = 256;
ASSERT_TRUE(Decision(cats, a, true));
// negative
a = -1;
ASSERT_TRUE(Decision(cats, a, true));
CatBitField bits{cats};
bits.Set(0);
a = -0.5;
ASSERT_TRUE(Decision(cats, a, true));
// round toward 0
a = 0.5;
ASSERT_FALSE(Decision(cats, a, true));
// valid
a = 13;
bits.Set(a);
ASSERT_FALSE(Decision(bits.Bits(), a, true));
}
} // namespace common
} // namespace xgboost