Validate out of range categorical value. (#7576)
* Use float in CPU categorical set to preserve the input value.
* Check out of range values.
parent d6ea5cc1ed
commit deab0e32ba
@@ -114,11 +114,11 @@ Miscellaneous
 By default, XGBoost assumes input categories are integers starting from 0 till the number
 of categories :math:`[0, n_categories)`. However, user might provide inputs with invalid
-values due to mistakes or missing values. It can be negative value, floating point value
-that can not be represented by 32-bit integer, or values that are larger than actual
-number of unique categories. During training this is validated but for prediction it's
-treated as the same as missing value for performance reasons. Lastly, missing values are
-treated as the same as numerical features.
+values due to mistakes or missing values. It can be negative value, integer values that
+can not be accurately represented by 32-bit floating point, or values that are larger than
+actual number of unique categories. During training this is validated but for prediction
+it's treated as the same as missing value for performance reasons. Lastly, missing values
+are treated as the same as numerical features (using the learned split direction).

 **********
 Next Steps
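The training-time validation described in the documentation change above can be reproduced from Python. A minimal sketch mirroring the run_invalid_category test added later in this commit (the tree_method and the injected out-of-range value are taken from that test; the exact cut-off is discussed under categorical.h below):

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.integers(low=0, high=4, size=1000).reshape(100, 10)
y = rng.normal(loc=0, scale=1, size=100)
# Inject a category far outside the valid range; it is rejected while the
# quantile sketch is built during training.
X[13, 7] = np.iinfo(np.int32).max + 1

Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10)
try:
    xgb.train({"tree_method": "approx"}, Xy)
except ValueError as err:
    print("invalid categorical value rejected:", err)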
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2020-2021 by XGBoost Contributors
+ * Copyright 2020-2022 by XGBoost Contributors
  * \file categorical.h
  */
 #ifndef XGBOOST_COMMON_CATEGORICAL_H_
@@ -32,9 +32,17 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
   return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
 }

+constexpr inline bst_cat_t OutOfRangeCat() {
+  // See the round trip assert in `InvalidCat`.
+  return static_cast<bst_cat_t>(16777217) - static_cast<bst_cat_t>(1);
+}
+
 inline XGBOOST_DEVICE bool InvalidCat(float cat) {
-  return cat < 0 || cat > static_cast<float>(std::numeric_limits<bst_cat_t>::max());
+  constexpr auto kMaxCat = OutOfRangeCat();
+  static_assert(static_cast<bst_cat_t>(static_cast<float>(kMaxCat)) == kMaxCat, "");
+  static_assert(static_cast<bst_cat_t>(static_cast<float>(kMaxCat + 1)) != kMaxCat + 1, "");
+  static_assert(static_cast<float>(kMaxCat + 1) == kMaxCat, "");
+  return cat < 0 || cat >= kMaxCat;
 }

 /* \brief Whether should it traverse to left branch of a tree.
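The static_asserts above pin down why 16777217 - 1 (that is, 2^24) is used as the cut-off: every integer up to 2^24 survives a round trip through 32-bit float, but 2^24 + 1 does not and rounds back down to 2^24. A small numpy sketch of the same property (numbers only, not XGBoost API):

import numpy as np

k_max_cat = 16777217 - 1  # 2**24, the value returned by OutOfRangeCat()

assert int(np.float32(k_max_cat)) == k_max_cat             # exact round trip
assert int(np.float32(k_max_cat + 1)) != k_max_cat + 1     # 2**24 + 1 is not representable
assert np.float32(k_max_cat + 1) == np.float32(k_max_cat)  # it rounds down to 2**24

# InvalidCat therefore compares with >=, so valid categories are the
# half-open range [0, 2**24).
def invalid_cat(cat: float) -> bool:
    return cat < 0 or cat >= k_max_cat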
@@ -53,9 +61,13 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat
 }

 inline void InvalidCategory() {
-  LOG(FATAL) << "Invalid categorical value detected. Categorical value "
-                "should be non-negative, less than maximum size of int32 and less than total "
-                "number of categories in training data.";
+  // OutOfRangeCat() can be accurately represented, but everything after it will be
+  // rounded toward it, so we use >= for comparison check. As a result, we require input
+  // values to be less than this last representable value.
+  auto str = std::to_string(OutOfRangeCat());
+  LOG(FATAL) << "Invalid categorical value detected. Categorical value should be non-negative, "
+                "less than total number of categories in training data and less than " +
+                str;
 }

 /*!
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2020-2021 by XGBoost Contributors
+ * Copyright 2020-2022 by XGBoost Contributors
  */
 #include <limits>
 #include <utility>
@@ -27,6 +27,7 @@ SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> column
   sketches_.resize(columns_size_.size());
   CHECK_GE(n_threads_, 1);
   categories_.resize(columns_size_.size());
+  has_categorical_ = std::any_of(feature_types_.cbegin(), feature_types_.cend(), IsCatOp{});
 }

 template <typename WQSketch>
@@ -187,7 +188,7 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
     if (is_dense) {
       for (size_t ii = begin; ii < end; ii++) {
         if (IsCat(feature_types_, ii)) {
-          categories_[ii].emplace(AsCat(p_inst[ii].fvalue));
+          categories_[ii].emplace(p_inst[ii].fvalue);
         } else {
           sketches_[ii].Push(p_inst[ii].fvalue, w);
         }
@@ -197,7 +198,7 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
         auto const& entry = p_inst[i];
         if (entry.index >= begin && entry.index < end) {
           if (IsCat(feature_types_, entry.index)) {
-            categories_[entry.index].emplace(AsCat(entry.fvalue));
+            categories_[entry.index].emplace(entry.fvalue);
           } else {
             sketches_[entry.index].Push(entry.fvalue, w);
           }
@@ -352,10 +353,10 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
   }
 }

-void AddCategories(std::set<bst_cat_t> const &categories, HistogramCuts *cuts) {
+void AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
   auto &cut_values = cuts->cut_values_.HostVector();
   for (auto const &v : categories) {
-    cut_values.push_back(v);
+    cut_values.push_back(AsCat(v));
   }
 }
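The sketch-container changes above work together: per-feature categories are now collected as the raw float feature values (std::set&lt;float&gt;), so a negative or out-of-range input is still visible when the sketch is finalized, and the narrowing cast via AsCat happens only when validated values are written into the cut values. A loose Python model of that flow (the helper names are illustrative, not the actual C++ call order):

import numpy as np

K_MAX_CAT = 2 ** 24  # OutOfRangeCat() in categorical.h

def invalid_cat(v: float) -> bool:
    return v < 0 or v >= K_MAX_CAT

def add_categories(categories, cut_values):
    # Mirrors AddCategories(): the narrowing cast (AsCat in C++) happens only
    # here, after validation, when the values become cut points.
    for v in sorted(categories):
        cut_values.append(int(v))

good = {0.0, 1.0, 3.0}
bad = {0.0, 1.0, float(np.iinfo(np.int32).max) + 1.0}

# Because the raw floats are preserved while sketching, the out-of-range
# entry is still detectable before cuts are built.
assert not any(invalid_cat(v) for v in good)
assert any(invalid_cat(v) for v in bad)

cuts = []
add_categories(good, cuts)
print(cuts)  # [0, 1, 3]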
@@ -410,6 +411,15 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
     CHECK_GT(cut_size, cuts->cut_ptrs_.HostVector().back());
     cuts->cut_ptrs_.HostVector().push_back(cut_size);
   }

+  if (has_categorical_) {
+    for (auto const &feat : categories_) {
+      if (std::any_of(feat.cbegin(), feat.cend(), InvalidCat)) {
+        InvalidCategory();
+      }
+    }
+  }
+
   monitor_.Stop(__func__);
 }
@@ -457,7 +467,7 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &
     // second pass
     if (IsCat(feature_types_, fidx)) {
       for (auto c : column) {
-        categories_[fidx].emplace(AsCat(c.fvalue));
+        categories_[fidx].emplace(c.fvalue);
       }
     } else {
       for (auto c : column) {
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2021 by Contributors
+ * Copyright 2014-2022 by XGBoost Contributors
  * \file quantile.h
  * \brief util to compute quantiles
  * \author Tianqi Chen
@@ -706,13 +706,14 @@ template <typename WQSketch>
 class SketchContainerImpl {
  protected:
   std::vector<WQSketch> sketches_;
-  std::vector<std::set<bst_cat_t>> categories_;
+  std::vector<std::set<float>> categories_;
   std::vector<FeatureType> const feature_types_;

   std::vector<bst_row_t> columns_size_;
   int32_t max_bins_;
   bool use_group_ind_{false};
   int32_t n_threads_;
+  bool has_categorical_{false};
   Monitor monitor_;

  public:
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2021 by XGBoost Contributors
+ * Copyright 2021-2022 by XGBoost Contributors
  */
 #ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
 #define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
@@ -303,8 +303,9 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
     if (candidate.split.is_cat) {
       std::vector<uint32_t> split_cats;
       if (candidate.split.cat_bits.empty()) {
-        CHECK_LT(candidate.split.split_value, std::numeric_limits<bst_cat_t>::max())
-            << "Categorical feature value too large.";
+        if (common::InvalidCat(candidate.split.split_value)) {
+          common::InvalidCategory();
+        }
         auto cat = common::AsCat(candidate.split.split_value);
         split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
         LBitField32 cat_bits;
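When a categorical split learned as a single split value is rebuilt as a category bitset, the value is first checked with common::InvalidCat, then narrowed with common::AsCat, and a 32-bit word array just large enough to hold bit `cat` is allocated. A small sketch of that sizing arithmetic, assuming (as the name LBitField32 suggests) 32-bit storage words:

def compute_storage_size(n_bits: int) -> int:
    # Number of uint32 words needed for `n_bits` bits, i.e. ceil(n_bits / 32).
    return (n_bits + 31) // 32

cat = 40                                     # already validated, so in [0, 2**24)
words = [0] * max(compute_storage_size(cat + 1), 1)
words[cat // 32] |= 1 << (cat % 32)          # the analogue of cat_bits.Set(cat)
assert len(words) == 2 and words[1] == 1 << 8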
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2021 XGBoost contributors
+ * Copyright 2017-2022 XGBoost contributors
  */
 #include <thrust/copy.h>
 #include <thrust/reduce.h>
@@ -572,11 +572,11 @@ struct GPUHistMakerDevice {
     if (is_cat) {
-      CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
-          << "Categorical feature value too large.";
-      auto cat = common::AsCat(candidate.split.fvalue);
-      if (common::InvalidCat(cat)) {
+      if (common::InvalidCat(candidate.split.fvalue)) {
        common::InvalidCategory();
      }
-      std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);
+      auto cat = common::AsCat(candidate.split.fvalue);
+      std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
       LBitField32 cats_bits(split_cats);
       cats_bits.Set(cat);
       dh::CopyToD(split_cats, &node_categories);
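On the GPU side the same InvalidCat check now runs on the raw split value before it is narrowed with AsCat. From Python, the corresponding user-visible behaviour (exercised by run_invalid_category below with tree_method="gpu_hist") is that a DeviceQuantileDMatrix builds its quantile sketch at construction, so an invalid category raises there rather than at train(). A minimal sketch, assuming cupy and a CUDA device are available:

import cupy as cp
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
# Normally distributed values include negatives, which are not valid categories.
X = cp.array(rng.normal(loc=0, scale=1, size=1000).reshape(100, 10))
y = cp.array(rng.normal(loc=0, scale=1, size=100))

try:
    xgb.DeviceQuantileDMatrix(X, y, feature_types=["c"] * 10)
except ValueError as err:
    print("rejected at construction:", err)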
@@ -60,20 +60,9 @@ class TestGPUUpdaters:
         rounds = 4
         self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")

     @pytest.mark.skipif(**tm.no_cupy())
     def test_invalid_categorical(self):
-        import cupy as cp
-        rng = np.random.default_rng()
-        X = rng.normal(loc=0, scale=1, size=1000).reshape(100, 10)
-        y = rng.normal(loc=0, scale=1, size=100)
-
-        # Check is performe during sketching.
-        Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10)
-        with pytest.raises(ValueError):
-            xgb.train({"tree_method": "gpu_hist"}, Xy)
-
-        X, y = cp.array(X), cp.array(y)
-        with pytest.raises(ValueError):
-            Xy = xgb.DeviceQuantileDMatrix(X, y, feature_types=["c"] * 10)
+        self.cputest.run_invalid_category("gpu_hist")

     @pytest.mark.skipif(**tm.no_cupy())
     @given(parameter_strategy, strategies.integers(1, 20),
@@ -133,6 +133,41 @@ class TestTreeMethod:
         w = [0, 0, 1, 0]
         model.fit(X, y, sample_weight=w)

+    def run_invalid_category(self, tree_method: str) -> None:
+        rng = np.random.default_rng()
+        # too large
+        X = rng.integers(low=0, high=4, size=1000).reshape(100, 10)
+        y = rng.normal(loc=0, scale=1, size=100)
+        X[13, 7] = np.iinfo(np.int32).max + 1
+
+        # Check is performed during sketching.
+        Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10)
+        with pytest.raises(ValueError):
+            xgb.train({"tree_method": tree_method}, Xy)
+
+        X[13, 7] = 16777216
+        Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10)
+        with pytest.raises(ValueError):
+            xgb.train({"tree_method": tree_method}, Xy)
+
+        # mixed positive and negative values
+        X = rng.normal(loc=0, scale=1, size=1000).reshape(100, 10)
+        y = rng.normal(loc=0, scale=1, size=100)
+
+        Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10)
+        with pytest.raises(ValueError):
+            xgb.train({"tree_method": tree_method}, Xy)
+
+        if tree_method == "gpu_hist":
+            import cupy as cp
+
+            X, y = cp.array(X), cp.array(y)
+            with pytest.raises(ValueError):
+                Xy = xgb.DeviceQuantileDMatrix(X, y, feature_types=["c"] * 10)
+
+    def test_invalid_category(self) -> None:
+        self.run_invalid_category("approx")
+
     def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
         onehot, label = tm.make_categorical(rows, cols, cats, True)
         cat, _ = tm.make_categorical(rows, cols, cats, False)