Clarify the behavior of invalid categorical value handling. (#7529)

This commit is contained in:
Jiaming Yuan 2022-01-13 16:11:52 +08:00 committed by GitHub
parent 20c0d60ac7
commit e5e47c3c99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 88 additions and 25 deletions

View File

@ -108,6 +108,18 @@ feature it's specified as ``"c"``. The Dask module in XGBoost has the same inte
:class:`dask.Array <dask.Array>` can also be used as categorical data. :class:`dask.Array <dask.Array>` can also be used as categorical data.
*************
Miscellaneous
*************
By default, XGBoost assumes input categories are integers starting from 0 till the number
of categories :math:`[0, n_categories)`. However, user might provide inputs with invalid
values due to mistakes or missing values. It can be negative value, floating point value
that can not be represented by 32-bit integer, or values that are larger than actual
number of unique categories. During training this is validated but for prediction it's
treated as the same as missing value for performance reasons. Lastly, missing values are
treated as the same as numerical features.
********** **********
Next Steps Next Steps
********** **********

View File

@ -5,6 +5,8 @@
#ifndef XGBOOST_COMMON_CATEGORICAL_H_ #ifndef XGBOOST_COMMON_CATEGORICAL_H_
#define XGBOOST_COMMON_CATEGORICAL_H_ #define XGBOOST_COMMON_CATEGORICAL_H_
#include <limits>
#include "bitfield.h" #include "bitfield.h"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/data.h" #include "xgboost/data.h"
@ -30,22 +32,30 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
return !ft.empty() && ft[fidx] == FeatureType::kCategorical; return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
} }
inline XGBOOST_DEVICE bool InvalidCat(float cat) {
return cat < 0 || cat > static_cast<float>(std::numeric_limits<bst_cat_t>::max());
}
/* \brief Whether should it traverse to left branch of a tree. /* \brief Whether should it traverse to left branch of a tree.
* *
* For one hot split, go to left if it's NOT the matching category. * For one hot split, go to left if it's NOT the matching category.
*/ */
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) { template <bool validate = true>
auto pos = CLBitField32::ToBitPos(cat); inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
if (pos.int_pos >= cats.size()) {
return true;
}
CLBitField32 const s_cats(cats); CLBitField32 const s_cats(cats);
return !s_cats.Check(cat); // FIXME: Size() is not accurate since it represents the size of bit set instead of
// actual number of categories.
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
return dft_left;
}
return !s_cats.Check(AsCat(cat));
} }
inline void InvalidCategory() { inline void InvalidCategory() {
LOG(FATAL) << "Invalid categorical value detected. Categorical value " LOG(FATAL) << "Invalid categorical value detected. Categorical value "
"should be non-negative."; "should be non-negative, less than maximum size of int32 and less than total "
"number of categories in training data.";
} }
/*! /*!
@ -58,9 +68,7 @@ inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task)
} }
struct IsCatOp { struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) { XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
return ft == FeatureType::kCategorical;
}
}; };
using CatBitField = LBitField32; using CatBitField = LBitField32;

View File

@ -581,14 +581,14 @@ void SketchContainer::AllReduce() {
} }
namespace { namespace {
struct InvalidCat { struct InvalidCatOp {
Span<float const> values; Span<float const> values;
Span<uint32_t const> ptrs; Span<uint32_t const> ptrs;
Span<FeatureType const> ft; Span<FeatureType const> ft;
XGBOOST_DEVICE bool operator()(size_t i) { XGBOOST_DEVICE bool operator()(size_t i) {
auto fidx = dh::SegmentId(ptrs, i); auto fidx = dh::SegmentId(ptrs, i);
return IsCat(ft, fidx) && values[i] < 0; return IsCat(ft, fidx) && InvalidCat(values[i]);
} }
}; };
} // anonymous namespace } // anonymous namespace
@ -687,10 +687,10 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
dh::XGBCachingDeviceAllocator<char> alloc; dh::XGBCachingDeviceAllocator<char> alloc;
auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan(); auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan();
auto it = thrust::make_counting_iterator(0ul); auto it = thrust::make_counting_iterator(0ul);
CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size()); CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size());
auto invalid = auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(), InvalidCatOp{out_cut_values, ptrs, d_ft});
InvalidCat{out_cut_values, ptrs, d_ft});
if (invalid) { if (invalid) {
InvalidCategory(); InvalidCategory();
} }

View File

@ -9,16 +9,16 @@
namespace xgboost { namespace xgboost {
namespace predictor { namespace predictor {
template <bool has_missing, bool has_categorical> template <bool has_missing, bool has_categorical>
inline XGBOOST_DEVICE bst_node_t inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue, float fvalue, bool is_missing,
bool is_missing, RegTree::CategoricalSplitMatrix const &cats) { RegTree::CategoricalSplitMatrix const &cats) {
if (has_missing && is_missing) { if (has_missing && is_missing) {
return node.DefaultChild(); return node.DefaultChild();
} else { } else {
if (has_categorical && common::IsCat(cats.split_type, nid)) { if (has_categorical && common::IsCat(cats.split_type, nid)) {
auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg, auto node_categories =
cats.node_ptr[nid].size); cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
return Decision(node_categories, common::AsCat(fvalue)) return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
? node.LeftChild() ? node.LeftChild()
: node.RightChild(); : node.RightChild();
} else { } else {

View File

@ -95,7 +95,7 @@ class ApproxRowPartitioner {
auto node_cats = categories.subspan(segment.beg, segment.size); auto node_cats = categories.subspan(segment.beg, segment.size);
bool go_left = true; bool go_left = true;
if (is_cat) { if (is_cat) {
go_left = common::Decision(node_cats, common::AsCat(cut_value)); go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft());
} else { } else {
go_left = cut_value <= candidate.split.split_value; go_left = cut_value <= candidate.split.split_value;
} }

View File

@ -396,7 +396,7 @@ struct GPUHistMakerDevice {
} else { } else {
bool go_left = true; bool go_left = true;
if (split_type == FeatureType::kCategorical) { if (split_type == FeatureType::kCategorical) {
go_left = common::Decision(node_cats, common::AsCat(cut_value)); go_left = common::Decision<false>(node_cats, cut_value, split_node.DefaultLeft());
} else { } else {
go_left = cut_value <= split_node.SplitCond(); go_left = cut_value <= split_node.SplitCond();
} }
@ -474,7 +474,7 @@ struct GPUHistMakerDevice {
auto node_cats = auto node_cats =
categories.subspan(categories_segments[position].beg, categories.subspan(categories_segments[position].beg,
categories_segments[position].size); categories_segments[position].size);
go_left = common::Decision(node_cats, common::AsCat(element)); go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
} else { } else {
go_left = element <= node.SplitCond(); go_left = element <= node.SplitCond();
} }
@ -573,7 +573,7 @@ struct GPUHistMakerDevice {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max()) CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large."; << "Categorical feature value too large.";
auto cat = common::AsCat(candidate.split.fvalue); auto cat = common::AsCat(candidate.split.fvalue);
if (cat < 0) { if (common::InvalidCat(cat)) {
common::InvalidCategory(); common::InvalidCategory();
} }
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0); std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);

View File

@ -0,0 +1,43 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <limits>
#include "../../../src/common/categorical.h"
namespace xgboost {
namespace common {
TEST(Categorical, Decision) {
// inf
float a = std::numeric_limits<float>::infinity();
ASSERT_TRUE(common::InvalidCat(a));
std::vector<uint32_t> cats(256, 0);
ASSERT_TRUE(Decision(cats, a, true));
// larger than size
a = 256;
ASSERT_TRUE(Decision(cats, a, true));
// negative
a = -1;
ASSERT_TRUE(Decision(cats, a, true));
CatBitField bits{cats};
bits.Set(0);
a = -0.5;
ASSERT_TRUE(Decision(cats, a, true));
// round toward 0
a = 0.5;
ASSERT_FALSE(Decision(cats, a, true));
// valid
a = 13;
bits.Set(a);
ASSERT_FALSE(Decision(bits.Bits(), a, true));
}
} // namespace common
} // namespace xgboost