Validate out of range categorical value. (#7576)

* Use float in CPU categorical set to preserve the input value.
* Check out of range values.
This commit is contained in:
Jiaming Yuan 2022-01-18 20:16:19 +08:00 committed by GitHub
parent d6ea5cc1ed
commit deab0e32ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 86 additions and 38 deletions

View File

@ -114,11 +114,11 @@ Miscellaneous
By default, XGBoost assumes input categories are integers starting from 0 till the number By default, XGBoost assumes input categories are integers starting from 0 till the number
of categories :math:`[0, n\_categories)`. However, users might provide inputs with invalid of categories :math:`[0, n\_categories)`. However, users might provide inputs with invalid
values due to mistakes or missing values. It can be negative value, floating point value values due to mistakes or missing values. It can be negative value, integer values that
that can not be represented by 32-bit integer, or values that are larger than actual can not be accurately represented by 32-bit floating point, or values that are larger than
number of unique categories. During training this is validated but for prediction it's actual number of unique categories. During training this is validated but for prediction
treated as the same as missing value for performance reasons. Lastly, missing values are it's treated as the same as missing value for performance reasons. Lastly, missing values
treated as the same as numerical features. are treated as the same as numerical features (using the learned split direction).
********** **********
Next Steps Next Steps

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2020-2021 by XGBoost Contributors * Copyright 2020-2022 by XGBoost Contributors
* \file categorical.h * \file categorical.h
*/ */
#ifndef XGBOOST_COMMON_CATEGORICAL_H_ #ifndef XGBOOST_COMMON_CATEGORICAL_H_
@ -32,9 +32,17 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
return !ft.empty() && ft[fidx] == FeatureType::kCategorical; return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
} }
/*!
 * \brief Exclusive upper bound for categorical values that are accepted as valid.
 *
 * Evaluates to 16777216 (2^24). The static asserts in `InvalidCat` check that this
 * value survives a round trip through float while the next integer does not.
 */
constexpr inline bst_cat_t OutOfRangeCat() {
  return static_cast<bst_cat_t>(1) << 24;
}
/*!
 * \brief Whether a raw feature value is unusable as a category.
 *
 * \param cat Feature value of a categorical feature.
 * \return true when `cat` is negative or not strictly below `OutOfRangeCat()`.
 */
inline XGBOOST_DEVICE bool InvalidCat(float cat) {
  constexpr auto kLimit = OutOfRangeCat();
  // The limit itself converts to float and back without loss ...
  static_assert(static_cast<bst_cat_t>(static_cast<float>(kLimit)) == kLimit, "");
  // ... the next integer does not survive the same round trip ...
  static_assert(static_cast<bst_cat_t>(static_cast<float>(kLimit + 1)) != kLimit + 1, "");
  // ... because as a float it compares equal to the limit.
  static_assert(static_cast<float>(kLimit + 1) == kLimit, "");
  bool const is_negative = cat < 0;
  bool const is_too_large = cat >= kLimit;
  return is_negative || is_too_large;
}
/* \brief Whether should it traverse to left branch of a tree. /* \brief Whether should it traverse to left branch of a tree.
@ -53,9 +61,13 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat
} }
inline void InvalidCategory() { inline void InvalidCategory() {
LOG(FATAL) << "Invalid categorical value detected. Categorical value " // OutOfRangeCat() can be accurately represented, but everything after it will be
"should be non-negative, less than maximum size of int32 and less than total " // rounded toward it, so we use >= for comparison check. As a result, we require input
"number of categories in training data."; // values to be less than this last representable value.
auto str = std::to_string(OutOfRangeCat());
LOG(FATAL) << "Invalid categorical value detected. Categorical value should be non-negative, "
"less than total umber of categories in training data and less than " +
str;
} }
/*! /*!

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2020-2021 by XGBoost Contributors * Copyright 2020-2022 by XGBoost Contributors
*/ */
#include <limits> #include <limits>
#include <utility> #include <utility>
@ -27,6 +27,7 @@ SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> column
sketches_.resize(columns_size_.size()); sketches_.resize(columns_size_.size());
CHECK_GE(n_threads_, 1); CHECK_GE(n_threads_, 1);
categories_.resize(columns_size_.size()); categories_.resize(columns_size_.size());
has_categorical_ = std::any_of(feature_types_.cbegin(), feature_types_.cend(), IsCatOp{});
} }
template <typename WQSketch> template <typename WQSketch>
@ -187,7 +188,7 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
if (is_dense) { if (is_dense) {
for (size_t ii = begin; ii < end; ii++) { for (size_t ii = begin; ii < end; ii++) {
if (IsCat(feature_types_, ii)) { if (IsCat(feature_types_, ii)) {
categories_[ii].emplace(AsCat(p_inst[ii].fvalue)); categories_[ii].emplace(p_inst[ii].fvalue);
} else { } else {
sketches_[ii].Push(p_inst[ii].fvalue, w); sketches_[ii].Push(p_inst[ii].fvalue, w);
} }
@ -197,7 +198,7 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
auto const& entry = p_inst[i]; auto const& entry = p_inst[i];
if (entry.index >= begin && entry.index < end) { if (entry.index >= begin && entry.index < end) {
if (IsCat(feature_types_, entry.index)) { if (IsCat(feature_types_, entry.index)) {
categories_[entry.index].emplace(AsCat(entry.fvalue)); categories_[entry.index].emplace(entry.fvalue);
} else { } else {
sketches_[entry.index].Push(entry.fvalue, w); sketches_[entry.index].Push(entry.fvalue, w);
} }
@ -352,10 +353,10 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
} }
} }
void AddCategories(std::set<bst_cat_t> const &categories, HistogramCuts *cuts) { void AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
auto &cut_values = cuts->cut_values_.HostVector(); auto &cut_values = cuts->cut_values_.HostVector();
for (auto const &v : categories) { for (auto const &v : categories) {
cut_values.push_back(v); cut_values.push_back(AsCat(v));
} }
} }
@ -410,6 +411,15 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
CHECK_GT(cut_size, cuts->cut_ptrs_.HostVector().back()); CHECK_GT(cut_size, cuts->cut_ptrs_.HostVector().back());
cuts->cut_ptrs_.HostVector().push_back(cut_size); cuts->cut_ptrs_.HostVector().push_back(cut_size);
} }
if (has_categorical_) {
for (auto const &feat : categories_) {
if (std::any_of(feat.cbegin(), feat.cend(), InvalidCat)) {
InvalidCategory();
}
}
}
monitor_.Stop(__func__); monitor_.Stop(__func__);
} }
@ -457,7 +467,7 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &
// second pass // second pass
if (IsCat(feature_types_, fidx)) { if (IsCat(feature_types_, fidx)) {
for (auto c : column) { for (auto c : column) {
categories_[fidx].emplace(AsCat(c.fvalue)); categories_[fidx].emplace(c.fvalue);
} }
} else { } else {
for (auto c : column) { for (auto c : column) {

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014-2021 by Contributors * Copyright 2014-2022 by XGBoost Contributors
* \file quantile.h * \file quantile.h
* \brief util to compute quantiles * \brief util to compute quantiles
* \author Tianqi Chen * \author Tianqi Chen
@ -706,13 +706,14 @@ template <typename WQSketch>
class SketchContainerImpl { class SketchContainerImpl {
protected: protected:
std::vector<WQSketch> sketches_; std::vector<WQSketch> sketches_;
std::vector<std::set<bst_cat_t>> categories_; std::vector<std::set<float>> categories_;
std::vector<FeatureType> const feature_types_; std::vector<FeatureType> const feature_types_;
std::vector<bst_row_t> columns_size_; std::vector<bst_row_t> columns_size_;
int32_t max_bins_; int32_t max_bins_;
bool use_group_ind_{false}; bool use_group_ind_{false};
int32_t n_threads_; int32_t n_threads_;
bool has_categorical_{false};
Monitor monitor_; Monitor monitor_;
public: public:

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2021 by XGBoost Contributors * Copyright 2021-2022 by XGBoost Contributors
*/ */
#ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_ #ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
#define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_ #define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
@ -303,8 +303,9 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
if (candidate.split.is_cat) { if (candidate.split.is_cat) {
std::vector<uint32_t> split_cats; std::vector<uint32_t> split_cats;
if (candidate.split.cat_bits.empty()) { if (candidate.split.cat_bits.empty()) {
CHECK_LT(candidate.split.split_value, std::numeric_limits<bst_cat_t>::max()) if (common::InvalidCat(candidate.split.split_value)) {
<< "Categorical feature value too large."; common::InvalidCategory();
}
auto cat = common::AsCat(candidate.split.split_value); auto cat = common::AsCat(candidate.split.split_value);
split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0); split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
LBitField32 cat_bits; LBitField32 cat_bits;

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2017-2021 XGBoost contributors * Copyright 2017-2022 XGBoost contributors
*/ */
#include <thrust/copy.h> #include <thrust/copy.h>
#include <thrust/reduce.h> #include <thrust/reduce.h>
@ -572,11 +572,11 @@ struct GPUHistMakerDevice {
if (is_cat) { if (is_cat) {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max()) CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large."; << "Categorical feature value too large.";
auto cat = common::AsCat(candidate.split.fvalue); if (common::InvalidCat(candidate.split.fvalue)) {
if (common::InvalidCat(cat)) {
common::InvalidCategory(); common::InvalidCategory();
} }
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0); auto cat = common::AsCat(candidate.split.fvalue);
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
LBitField32 cats_bits(split_cats); LBitField32 cats_bits(split_cats);
cats_bits.Set(cat); cats_bits.Set(cat);
dh::CopyToD(split_cats, &node_categories); dh::CopyToD(split_cats, &node_categories);

View File

@ -60,20 +60,9 @@ class TestGPUUpdaters:
rounds = 4 rounds = 4
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist") self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
@pytest.mark.skipif(**tm.no_cupy())
def test_invalid_categorical(self):
    # Run the shared invalid-category checks with gpu_hist. The cupy skip marker is
    # kept because the shared helper needs cupy for its gpu_hist branch.
    self.cputest.run_invalid_category("gpu_hist")
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())
@given(parameter_strategy, strategies.integers(1, 20), @given(parameter_strategy, strategies.integers(1, 20),

View File

@ -133,6 +133,41 @@ class TestTreeMethod:
w = [0, 0, 1, 0] w = [0, 0, 1, 0]
model.fit(X, y, sample_weight=w) model.fit(X, y, sample_weight=w)
def run_invalid_category(self, tree_method: str) -> None:
    """Check that invalid categorical values are rejected with ``ValueError``.

    Three inputs are tried with every feature marked categorical: a value above the
    int32 maximum, the value 16777216, and normally distributed floats (which mix
    positive and negative values). For ``gpu_hist`` the last input is also passed
    to ``DeviceQuantileDMatrix`` as cupy arrays.

    Parameters
    ----------
    tree_method :
        Value passed as the ``tree_method`` training parameter.
    """
    rng = np.random.default_rng()
    # too large
    X = rng.integers(low=0, high=4, size=1000).reshape(100, 10)
    y = rng.normal(loc=0, scale=1, size=100)
    X[13, 7] = np.iinfo(np.int32).max + 1

    # Check is performed during sketching.
    Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10)
    with pytest.raises(ValueError):
        xgb.train({"tree_method": tree_method}, Xy)

    X[13, 7] = 16777216
    Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10)
    with pytest.raises(ValueError):
        xgb.train({"tree_method": tree_method}, Xy)

    # mixed positive and negative values
    X = rng.normal(loc=0, scale=1, size=1000).reshape(100, 10)
    y = rng.normal(loc=0, scale=1, size=100)

    Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10)
    with pytest.raises(ValueError):
        xgb.train({"tree_method": tree_method}, Xy)

    if tree_method == "gpu_hist":
        # Skip (rather than error) the device-input check when cupy is unavailable;
        # this helper is shared between the CPU and GPU test classes.
        cp = pytest.importorskip("cupy")

        X, y = cp.array(X), cp.array(y)
        with pytest.raises(ValueError):
            Xy = xgb.DeviceQuantileDMatrix(X, y, feature_types=["c"] * 10)
def test_invalid_category(self) -> None:
    # Run the shared invalid-category checks with the "approx" tree method.
    self.run_invalid_category("approx")
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method): def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
onehot, label = tm.make_categorical(rows, cols, cats, True) onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False) cat, _ = tm.make_categorical(rows, cols, cats, False)