Validate out of range categorical value. (#7576)
* Use float in CPU categorical set to preserve the input value. * Check out of range values.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2020-2021 by XGBoost Contributors
|
||||
* Copyright 2020-2022 by XGBoost Contributors
|
||||
* \file categorical.h
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
|
||||
@@ -32,9 +32,17 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
|
||||
return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
|
||||
}
|
||||
|
||||
constexpr inline bst_cat_t OutOfRangeCat() {
|
||||
// See the round trip assert in `InvalidCat`.
|
||||
return static_cast<bst_cat_t>(16777217) - static_cast<bst_cat_t>(1);
|
||||
}
|
||||
|
||||
inline XGBOOST_DEVICE bool InvalidCat(float cat) {
|
||||
return cat < 0 || cat > static_cast<float>(std::numeric_limits<bst_cat_t>::max());
|
||||
constexpr auto kMaxCat = OutOfRangeCat();
|
||||
static_assert(static_cast<bst_cat_t>(static_cast<float>(kMaxCat)) == kMaxCat, "");
|
||||
static_assert(static_cast<bst_cat_t>(static_cast<float>(kMaxCat + 1)) != kMaxCat + 1, "");
|
||||
static_assert(static_cast<float>(kMaxCat + 1) == kMaxCat, "");
|
||||
return cat < 0 || cat >= kMaxCat;
|
||||
}
|
||||
|
||||
/* \brief Whether should it traverse to left branch of a tree.
|
||||
@@ -53,9 +61,13 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat
|
||||
}
|
||||
|
||||
inline void InvalidCategory() {
|
||||
LOG(FATAL) << "Invalid categorical value detected. Categorical value "
|
||||
"should be non-negative, less than maximum size of int32 and less than total "
|
||||
"number of categories in training data.";
|
||||
// OutOfRangeCat() can be accurately represented, but everything after it will be
|
||||
// rounded toward it, so we use >= for comparison check. As a result, we require input
|
||||
// values to be less than this last representable value.
|
||||
auto str = std::to_string(OutOfRangeCat());
|
||||
LOG(FATAL) << "Invalid categorical value detected. Categorical value should be non-negative, "
|
||||
"less than total umber of categories in training data and less than " +
|
||||
str;
|
||||
}
|
||||
|
||||
/*!
|
||||
|
||||
Reference in New Issue
Block a user