Add test for invalid categorical data values. (#7380)
* Add test for invalid categorical data values.
* Add check during sketching.
This commit is contained in:
@@ -42,9 +42,9 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t
|
||||
return !s_cats.Check(cat);
|
||||
}
|
||||
|
||||
inline void CheckCat(bst_cat_t cat) {
|
||||
CHECK_GE(cat, 0) << "Invalid categorical value detected. Categorical value "
|
||||
"should be non-negative.";
|
||||
/**
 * \brief Report an invalid categorical value through the fatal log channel.
 *
 * Performs no validation itself: the caller decides that a value is invalid and
 * then invokes this helper to emit the diagnostic.
 *
 * NOTE(review): LOG(FATAL) is presumably non-returning (throws or aborts) --
 * confirm against the logging backend in use.
 */
inline void InvalidCategory() {
  LOG(FATAL) << "Invalid categorical value detected. Categorical value should be non-negative.";
}
|
||||
|
||||
struct IsCatOp {
|
||||
|
||||
@@ -580,6 +580,19 @@ void SketchContainer::AllReduce() {
|
||||
timer_.Stop(__func__);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct InvalidCat {
|
||||
Span<float const> values;
|
||||
Span<uint32_t const> ptrs;
|
||||
Span<FeatureType const> ft;
|
||||
|
||||
XGBOOST_DEVICE bool operator()(size_t i) {
|
||||
auto fidx = dh::SegmentId(ptrs, i);
|
||||
return IsCat(ft, fidx) && values[i] < 0;
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
||||
timer_.Start(__func__);
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
@@ -669,6 +682,19 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
||||
assert(idx+1 < in_column.size());
|
||||
out_column[idx] = in_column[idx+1].value;
|
||||
});
|
||||
|
||||
if (has_categorical_) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan();
|
||||
auto it = thrust::make_counting_iterator(0ul);
|
||||
CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size());
|
||||
auto invalid =
|
||||
thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
|
||||
InvalidCat{out_cut_values, ptrs, d_ft});
|
||||
if (invalid) {
|
||||
InvalidCategory();
|
||||
}
|
||||
}
|
||||
timer_.Stop(__func__);
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
Reference in New Issue
Block a user