Clarify the behavior of invalid categorical value handling. (#7529)

This commit is contained in:
Jiaming Yuan
2022-01-13 16:11:52 +08:00
committed by GitHub
parent 20c0d60ac7
commit e5e47c3c99
7 changed files with 88 additions and 25 deletions

View File

@@ -5,6 +5,8 @@
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
#define XGBOOST_COMMON_CATEGORICAL_H_
#include <limits>
#include "bitfield.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
@@ -30,22 +32,30 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
}
inline XGBOOST_DEVICE bool InvalidCat(float cat) {
return cat < 0 || cat > static_cast<float>(std::numeric_limits<bst_cat_t>::max());
}
/* \brief Whether should it traverse to left branch of a tree.
*
* For one hot split, go to left if it's NOT the matching category.
*/
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) {
auto pos = CLBitField32::ToBitPos(cat);
if (pos.int_pos >= cats.size()) {
return true;
}
template <bool validate = true>
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
CLBitField32 const s_cats(cats);
return !s_cats.Check(cat);
// FIXME: Size() is not accurate since it represents the size of bit set instead of
// actual number of categories.
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
return dft_left;
}
return !s_cats.Check(AsCat(cat));
}
inline void InvalidCategory() {
LOG(FATAL) << "Invalid categorical value detected. Categorical value "
"should be non-negative.";
"should be non-negative, less than maximum size of int32 and less than total "
"number of categories in training data.";
}
/*!
@@ -58,9 +68,7 @@ inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task)
}
struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) {
return ft == FeatureType::kCategorical;
}
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
};
using CatBitField = LBitField32;

View File

@@ -581,14 +581,14 @@ void SketchContainer::AllReduce() {
}
namespace {
struct InvalidCat {
struct InvalidCatOp {
Span<float const> values;
Span<uint32_t const> ptrs;
Span<FeatureType const> ft;
XGBOOST_DEVICE bool operator()(size_t i) {
auto fidx = dh::SegmentId(ptrs, i);
return IsCat(ft, fidx) && values[i] < 0;
return IsCat(ft, fidx) && InvalidCat(values[i]);
}
};
} // anonymous namespace
@@ -687,10 +687,10 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan();
auto it = thrust::make_counting_iterator(0ul);
CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size());
auto invalid =
thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
InvalidCat{out_cut_values, ptrs, d_ft});
auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
InvalidCatOp{out_cut_values, ptrs, d_ft});
if (invalid) {
InvalidCategory();
}

View File

@@ -9,16 +9,16 @@
namespace xgboost {
namespace predictor {
template <bool has_missing, bool has_categorical>
inline XGBOOST_DEVICE bst_node_t
GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue,
bool is_missing, RegTree::CategoricalSplitMatrix const &cats) {
inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
float fvalue, bool is_missing,
RegTree::CategoricalSplitMatrix const &cats) {
if (has_missing && is_missing) {
return node.DefaultChild();
} else {
if (has_categorical && common::IsCat(cats.split_type, nid)) {
auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg,
cats.node_ptr[nid].size);
return Decision(node_categories, common::AsCat(fvalue))
auto node_categories =
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
? node.LeftChild()
: node.RightChild();
} else {

View File

@@ -95,7 +95,7 @@ class ApproxRowPartitioner {
auto node_cats = categories.subspan(segment.beg, segment.size);
bool go_left = true;
if (is_cat) {
go_left = common::Decision(node_cats, common::AsCat(cut_value));
go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft());
} else {
go_left = cut_value <= candidate.split.split_value;
}

View File

@@ -396,7 +396,7 @@ struct GPUHistMakerDevice {
} else {
bool go_left = true;
if (split_type == FeatureType::kCategorical) {
go_left = common::Decision(node_cats, common::AsCat(cut_value));
go_left = common::Decision<false>(node_cats, cut_value, split_node.DefaultLeft());
} else {
go_left = cut_value <= split_node.SplitCond();
}
@@ -474,7 +474,7 @@ struct GPUHistMakerDevice {
auto node_cats =
categories.subspan(categories_segments[position].beg,
categories_segments[position].size);
go_left = common::Decision(node_cats, common::AsCat(element));
go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
} else {
go_left = element <= node.SplitCond();
}
@@ -573,7 +573,7 @@ struct GPUHistMakerDevice {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large.";
auto cat = common::AsCat(candidate.split.fvalue);
if (cat < 0) {
if (common::InvalidCat(cat)) {
common::InvalidCategory();
}
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);