Clarify the behavior of invalid categorical value handling. (#7529)
This commit is contained in:
@@ -5,6 +5,8 @@
|
||||
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
|
||||
#define XGBOOST_COMMON_CATEGORICAL_H_
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "bitfield.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
@@ -30,22 +32,30 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
|
||||
return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
|
||||
}
|
||||
|
||||
|
||||
/*!
 * \brief Whether a floating point value is out of range for a category.
 *
 * Returns true when `cat` is negative, or when it is greater than the largest
 * `bst_cat_t` value after that maximum has been converted to float.
 *
 * NOTE(review): if `bst_cat_t` is a 32-bit integer, its maximum is not exactly
 * representable as float, so the upper boundary may be off by rounding -- confirm
 * this is the intended limit.
 */
inline XGBOOST_DEVICE bool InvalidCat(float cat) {
|
||||
// NOTE(review): for a NaN input both comparisons are false, so NaN is reported
// as valid here -- confirm callers handle missing values before this check.
return cat < 0 || cat > static_cast<float>(std::numeric_limits<bst_cat_t>::max());
|
||||
}
|
||||
|
||||
/* \brief Whether it should traverse to the left branch of a tree.
|
||||
*
|
||||
* For one hot split, go to left if it's NOT the matching category.
|
||||
*/
|
||||
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) {
|
||||
auto pos = CLBitField32::ToBitPos(cat);
|
||||
if (pos.int_pos >= cats.size()) {
|
||||
return true;
|
||||
}
|
||||
template <bool validate = true>
|
||||
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
|
||||
CLBitField32 const s_cats(cats);
|
||||
return !s_cats.Check(cat);
|
||||
// FIXME: Size() is not accurate since it represents the size of the bit set instead of
|
||||
// the actual number of categories.
|
||||
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
|
||||
return dft_left;
|
||||
}
|
||||
return !s_cats.Check(AsCat(cat));
|
||||
}
|
||||
|
||||
inline void InvalidCategory() {
|
||||
LOG(FATAL) << "Invalid categorical value detected. Categorical value "
|
||||
"should be non-negative.";
|
||||
"should be non-negative, less than maximum size of int32 and less than total "
|
||||
"number of categories in training data.";
|
||||
}
|
||||
|
||||
/*!
|
||||
@@ -58,9 +68,7 @@ inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task)
|
||||
}
|
||||
|
||||
struct IsCatOp {
|
||||
XGBOOST_DEVICE bool operator()(FeatureType ft) {
|
||||
return ft == FeatureType::kCategorical;
|
||||
}
|
||||
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
|
||||
};
|
||||
|
||||
using CatBitField = LBitField32;
|
||||
|
||||
@@ -581,14 +581,14 @@ void SketchContainer::AllReduce() {
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct InvalidCat {
|
||||
struct InvalidCatOp {
|
||||
Span<float const> values;
|
||||
Span<uint32_t const> ptrs;
|
||||
Span<FeatureType const> ft;
|
||||
|
||||
XGBOOST_DEVICE bool operator()(size_t i) {
|
||||
auto fidx = dh::SegmentId(ptrs, i);
|
||||
return IsCat(ft, fidx) && values[i] < 0;
|
||||
return IsCat(ft, fidx) && InvalidCat(values[i]);
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
@@ -687,10 +687,10 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan();
|
||||
auto it = thrust::make_counting_iterator(0ul);
|
||||
|
||||
CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size());
|
||||
auto invalid =
|
||||
thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
|
||||
InvalidCat{out_cut_values, ptrs, d_ft});
|
||||
auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
|
||||
InvalidCatOp{out_cut_values, ptrs, d_ft});
|
||||
if (invalid) {
|
||||
InvalidCategory();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user