Clarify the behavior of invalid categorical value handling. (#7529)
This commit is contained in:
parent
20c0d60ac7
commit
e5e47c3c99
@ -108,6 +108,18 @@ feature it's specified as ``"c"``. The Dask module in XGBoost has the same inte
|
|||||||
:class:`dask.Array <dask.Array>` can also be used as categorical data.
|
:class:`dask.Array <dask.Array>` can also be used as categorical data.
|
||||||
|
|
||||||
|
|
||||||
|
*************
|
||||||
|
Miscellaneous
|
||||||
|
*************
|
||||||
|
|
||||||
|
By default, XGBoost assumes input categories are integers starting from 0 till the number
|
||||||
|
of categories :math:`[0, n_categories)`. However, user might provide inputs with invalid
|
||||||
|
values due to mistakes or missing values. It can be negative value, floating point value
|
||||||
|
that can not be represented by 32-bit integer, or values that are larger than actual
|
||||||
|
number of unique categories. During training this is validated but for prediction it's
|
||||||
|
treated as the same as missing value for performance reasons. Lastly, missing values are
|
||||||
|
treated as the same as numerical features.
|
||||||
|
|
||||||
**********
|
**********
|
||||||
Next Steps
|
Next Steps
|
||||||
**********
|
**********
|
||||||
|
|||||||
@ -5,6 +5,8 @@
|
|||||||
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
|
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
|
||||||
#define XGBOOST_COMMON_CATEGORICAL_H_
|
#define XGBOOST_COMMON_CATEGORICAL_H_
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
#include "bitfield.h"
|
#include "bitfield.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
@ -30,22 +32,30 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
|
|||||||
return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
|
return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline XGBOOST_DEVICE bool InvalidCat(float cat) {
|
||||||
|
return cat < 0 || cat > static_cast<float>(std::numeric_limits<bst_cat_t>::max());
|
||||||
|
}
|
||||||
|
|
||||||
/* \brief Whether should it traverse to left branch of a tree.
|
/* \brief Whether should it traverse to left branch of a tree.
|
||||||
*
|
*
|
||||||
* For one hot split, go to left if it's NOT the matching category.
|
* For one hot split, go to left if it's NOT the matching category.
|
||||||
*/
|
*/
|
||||||
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) {
|
template <bool validate = true>
|
||||||
auto pos = CLBitField32::ToBitPos(cat);
|
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
|
||||||
if (pos.int_pos >= cats.size()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
CLBitField32 const s_cats(cats);
|
CLBitField32 const s_cats(cats);
|
||||||
return !s_cats.Check(cat);
|
// FIXME: Size() is not accurate since it represents the size of bit set instead of
|
||||||
|
// actual number of categories.
|
||||||
|
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
|
||||||
|
return dft_left;
|
||||||
|
}
|
||||||
|
return !s_cats.Check(AsCat(cat));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void InvalidCategory() {
|
inline void InvalidCategory() {
|
||||||
LOG(FATAL) << "Invalid categorical value detected. Categorical value "
|
LOG(FATAL) << "Invalid categorical value detected. Categorical value "
|
||||||
"should be non-negative.";
|
"should be non-negative, less than maximum size of int32 and less than total "
|
||||||
|
"number of categories in training data.";
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
@ -58,9 +68,7 @@ inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task)
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct IsCatOp {
|
struct IsCatOp {
|
||||||
XGBOOST_DEVICE bool operator()(FeatureType ft) {
|
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
|
||||||
return ft == FeatureType::kCategorical;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
using CatBitField = LBitField32;
|
using CatBitField = LBitField32;
|
||||||
|
|||||||
@ -581,14 +581,14 @@ void SketchContainer::AllReduce() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
struct InvalidCat {
|
struct InvalidCatOp {
|
||||||
Span<float const> values;
|
Span<float const> values;
|
||||||
Span<uint32_t const> ptrs;
|
Span<uint32_t const> ptrs;
|
||||||
Span<FeatureType const> ft;
|
Span<FeatureType const> ft;
|
||||||
|
|
||||||
XGBOOST_DEVICE bool operator()(size_t i) {
|
XGBOOST_DEVICE bool operator()(size_t i) {
|
||||||
auto fidx = dh::SegmentId(ptrs, i);
|
auto fidx = dh::SegmentId(ptrs, i);
|
||||||
return IsCat(ft, fidx) && values[i] < 0;
|
return IsCat(ft, fidx) && InvalidCat(values[i]);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
@ -687,10 +687,10 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
|||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan();
|
auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan();
|
||||||
auto it = thrust::make_counting_iterator(0ul);
|
auto it = thrust::make_counting_iterator(0ul);
|
||||||
|
|
||||||
CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size());
|
CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size());
|
||||||
auto invalid =
|
auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
|
||||||
thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
|
InvalidCatOp{out_cut_values, ptrs, d_ft});
|
||||||
InvalidCat{out_cut_values, ptrs, d_ft});
|
|
||||||
if (invalid) {
|
if (invalid) {
|
||||||
InvalidCategory();
|
InvalidCategory();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,16 +9,16 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace predictor {
|
namespace predictor {
|
||||||
template <bool has_missing, bool has_categorical>
|
template <bool has_missing, bool has_categorical>
|
||||||
inline XGBOOST_DEVICE bst_node_t
|
inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
|
||||||
GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue,
|
float fvalue, bool is_missing,
|
||||||
bool is_missing, RegTree::CategoricalSplitMatrix const &cats) {
|
RegTree::CategoricalSplitMatrix const &cats) {
|
||||||
if (has_missing && is_missing) {
|
if (has_missing && is_missing) {
|
||||||
return node.DefaultChild();
|
return node.DefaultChild();
|
||||||
} else {
|
} else {
|
||||||
if (has_categorical && common::IsCat(cats.split_type, nid)) {
|
if (has_categorical && common::IsCat(cats.split_type, nid)) {
|
||||||
auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg,
|
auto node_categories =
|
||||||
cats.node_ptr[nid].size);
|
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
|
||||||
return Decision(node_categories, common::AsCat(fvalue))
|
return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
|
||||||
? node.LeftChild()
|
? node.LeftChild()
|
||||||
: node.RightChild();
|
: node.RightChild();
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -95,7 +95,7 @@ class ApproxRowPartitioner {
|
|||||||
auto node_cats = categories.subspan(segment.beg, segment.size);
|
auto node_cats = categories.subspan(segment.beg, segment.size);
|
||||||
bool go_left = true;
|
bool go_left = true;
|
||||||
if (is_cat) {
|
if (is_cat) {
|
||||||
go_left = common::Decision(node_cats, common::AsCat(cut_value));
|
go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft());
|
||||||
} else {
|
} else {
|
||||||
go_left = cut_value <= candidate.split.split_value;
|
go_left = cut_value <= candidate.split.split_value;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -396,7 +396,7 @@ struct GPUHistMakerDevice {
|
|||||||
} else {
|
} else {
|
||||||
bool go_left = true;
|
bool go_left = true;
|
||||||
if (split_type == FeatureType::kCategorical) {
|
if (split_type == FeatureType::kCategorical) {
|
||||||
go_left = common::Decision(node_cats, common::AsCat(cut_value));
|
go_left = common::Decision<false>(node_cats, cut_value, split_node.DefaultLeft());
|
||||||
} else {
|
} else {
|
||||||
go_left = cut_value <= split_node.SplitCond();
|
go_left = cut_value <= split_node.SplitCond();
|
||||||
}
|
}
|
||||||
@ -474,7 +474,7 @@ struct GPUHistMakerDevice {
|
|||||||
auto node_cats =
|
auto node_cats =
|
||||||
categories.subspan(categories_segments[position].beg,
|
categories.subspan(categories_segments[position].beg,
|
||||||
categories_segments[position].size);
|
categories_segments[position].size);
|
||||||
go_left = common::Decision(node_cats, common::AsCat(element));
|
go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
|
||||||
} else {
|
} else {
|
||||||
go_left = element <= node.SplitCond();
|
go_left = element <= node.SplitCond();
|
||||||
}
|
}
|
||||||
@ -573,7 +573,7 @@ struct GPUHistMakerDevice {
|
|||||||
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
|
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
|
||||||
<< "Categorical feature value too large.";
|
<< "Categorical feature value too large.";
|
||||||
auto cat = common::AsCat(candidate.split.fvalue);
|
auto cat = common::AsCat(candidate.split.fvalue);
|
||||||
if (cat < 0) {
|
if (common::InvalidCat(cat)) {
|
||||||
common::InvalidCategory();
|
common::InvalidCategory();
|
||||||
}
|
}
|
||||||
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);
|
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);
|
||||||
|
|||||||
43
tests/cpp/common/test_categorical.cc
Normal file
43
tests/cpp/common/test_categorical.cc
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2021 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include "../../../src/common/categorical.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
TEST(Categorical, Decision) {
|
||||||
|
// inf
|
||||||
|
float a = std::numeric_limits<float>::infinity();
|
||||||
|
|
||||||
|
ASSERT_TRUE(common::InvalidCat(a));
|
||||||
|
std::vector<uint32_t> cats(256, 0);
|
||||||
|
ASSERT_TRUE(Decision(cats, a, true));
|
||||||
|
|
||||||
|
// larger than size
|
||||||
|
a = 256;
|
||||||
|
ASSERT_TRUE(Decision(cats, a, true));
|
||||||
|
|
||||||
|
// negative
|
||||||
|
a = -1;
|
||||||
|
ASSERT_TRUE(Decision(cats, a, true));
|
||||||
|
|
||||||
|
CatBitField bits{cats};
|
||||||
|
bits.Set(0);
|
||||||
|
a = -0.5;
|
||||||
|
ASSERT_TRUE(Decision(cats, a, true));
|
||||||
|
|
||||||
|
// round toward 0
|
||||||
|
a = 0.5;
|
||||||
|
ASSERT_FALSE(Decision(cats, a, true));
|
||||||
|
|
||||||
|
// valid
|
||||||
|
a = 13;
|
||||||
|
bits.Set(a);
|
||||||
|
ASSERT_FALSE(Decision(bits.Bits(), a, true));
|
||||||
|
}
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
Loading…
x
Reference in New Issue
Block a user