Validate out of range categorical value. (#7576)

* Use float in CPU categorical set to preserve the input value.
* Check out of range values.
This commit is contained in:
Jiaming Yuan
2022-01-18 20:16:19 +08:00
committed by GitHub
parent d6ea5cc1ed
commit deab0e32ba
8 changed files with 86 additions and 38 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2020-2021 by XGBoost Contributors
* Copyright 2020-2022 by XGBoost Contributors
* \file categorical.h
*/
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
@@ -32,9 +32,17 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
}
constexpr inline bst_cat_t OutOfRangeCat() {
// See the round trip assert in `InvalidCat`.
return static_cast<bst_cat_t>(16777217) - static_cast<bst_cat_t>(1);
}
inline XGBOOST_DEVICE bool InvalidCat(float cat) {
return cat < 0 || cat > static_cast<float>(std::numeric_limits<bst_cat_t>::max());
constexpr auto kMaxCat = OutOfRangeCat();
static_assert(static_cast<bst_cat_t>(static_cast<float>(kMaxCat)) == kMaxCat, "");
static_assert(static_cast<bst_cat_t>(static_cast<float>(kMaxCat + 1)) != kMaxCat + 1, "");
static_assert(static_cast<float>(kMaxCat + 1) == kMaxCat, "");
return cat < 0 || cat >= kMaxCat;
}
/* \brief Whether should it traverse to left branch of a tree.
@@ -53,9 +61,13 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat
}
inline void InvalidCategory() {
LOG(FATAL) << "Invalid categorical value detected. Categorical value "
"should be non-negative, less than maximum size of int32 and less than total "
"number of categories in training data.";
// OutOfRangeCat() can be accurately represented, but everything after it will be
// rounded toward it, so we use >= for comparison check. As a result, we require input
// values to be less than this last representable value.
auto str = std::to_string(OutOfRangeCat());
LOG(FATAL) << "Invalid categorical value detected. Categorical value should be non-negative, "
"less than total umber of categories in training data and less than " +
str;
}
/*!

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2020-2021 by XGBoost Contributors
* Copyright 2020-2022 by XGBoost Contributors
*/
#include <limits>
#include <utility>
@@ -27,6 +27,7 @@ SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> column
sketches_.resize(columns_size_.size());
CHECK_GE(n_threads_, 1);
categories_.resize(columns_size_.size());
has_categorical_ = std::any_of(feature_types_.cbegin(), feature_types_.cend(), IsCatOp{});
}
template <typename WQSketch>
@@ -187,7 +188,7 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
if (is_dense) {
for (size_t ii = begin; ii < end; ii++) {
if (IsCat(feature_types_, ii)) {
categories_[ii].emplace(AsCat(p_inst[ii].fvalue));
categories_[ii].emplace(p_inst[ii].fvalue);
} else {
sketches_[ii].Push(p_inst[ii].fvalue, w);
}
@@ -197,7 +198,7 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
auto const& entry = p_inst[i];
if (entry.index >= begin && entry.index < end) {
if (IsCat(feature_types_, entry.index)) {
categories_[entry.index].emplace(AsCat(entry.fvalue));
categories_[entry.index].emplace(entry.fvalue);
} else {
sketches_[entry.index].Push(entry.fvalue, w);
}
@@ -352,10 +353,10 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
}
}
void AddCategories(std::set<bst_cat_t> const &categories, HistogramCuts *cuts) {
void AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
auto &cut_values = cuts->cut_values_.HostVector();
for (auto const &v : categories) {
cut_values.push_back(v);
cut_values.push_back(AsCat(v));
}
}
@@ -410,6 +411,15 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
CHECK_GT(cut_size, cuts->cut_ptrs_.HostVector().back());
cuts->cut_ptrs_.HostVector().push_back(cut_size);
}
if (has_categorical_) {
for (auto const &feat : categories_) {
if (std::any_of(feat.cbegin(), feat.cend(), InvalidCat)) {
InvalidCategory();
}
}
}
monitor_.Stop(__func__);
}
@@ -457,7 +467,7 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &
// second pass
if (IsCat(feature_types_, fidx)) {
for (auto c : column) {
categories_[fidx].emplace(AsCat(c.fvalue));
categories_[fidx].emplace(c.fvalue);
}
} else {
for (auto c : column) {

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2021 by Contributors
* Copyright 2014-2022 by XGBoost Contributors
* \file quantile.h
* \brief util to compute quantiles
* \author Tianqi Chen
@@ -706,13 +706,14 @@ template <typename WQSketch>
class SketchContainerImpl {
protected:
std::vector<WQSketch> sketches_;
std::vector<std::set<bst_cat_t>> categories_;
std::vector<std::set<float>> categories_;
std::vector<FeatureType> const feature_types_;
std::vector<bst_row_t> columns_size_;
int32_t max_bins_;
bool use_group_ind_{false};
int32_t n_threads_;
bool has_categorical_{false};
Monitor monitor_;
public:

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2021 by XGBoost Contributors
* Copyright 2021-2022 by XGBoost Contributors
*/
#ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
#define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
@@ -303,8 +303,9 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
if (candidate.split.is_cat) {
std::vector<uint32_t> split_cats;
if (candidate.split.cat_bits.empty()) {
CHECK_LT(candidate.split.split_value, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large.";
if (common::InvalidCat(candidate.split.split_value)) {
common::InvalidCategory();
}
auto cat = common::AsCat(candidate.split.split_value);
split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
LBitField32 cat_bits;

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2017-2021 XGBoost contributors
* Copyright 2017-2022 XGBoost contributors
*/
#include <thrust/copy.h>
#include <thrust/reduce.h>
@@ -572,11 +572,11 @@ struct GPUHistMakerDevice {
if (is_cat) {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large.";
auto cat = common::AsCat(candidate.split.fvalue);
if (common::InvalidCat(cat)) {
if (common::InvalidCat(candidate.split.fvalue)) {
common::InvalidCategory();
}
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);
auto cat = common::AsCat(candidate.split.fvalue);
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
LBitField32 cats_bits(split_cats);
cats_bits.Set(cat);
dh::CopyToD(split_cats, &node_categories);