More categorical tests and disable shap sparse test. (#6219)
* Fix tree load with 32 category.
This commit is contained in:
@@ -186,7 +186,9 @@ Json& JsonObject::operator[](int ind) {
|
||||
}
|
||||
|
||||
bool JsonObject::operator==(Value const& rhs) const {
|
||||
if (!IsA<JsonObject>(&rhs)) { return false; }
|
||||
if (!IsA<JsonObject>(&rhs)) {
|
||||
return false;
|
||||
}
|
||||
return object_ == Cast<JsonObject const>(&rhs)->GetObject();
|
||||
}
|
||||
|
||||
@@ -275,10 +277,14 @@ Json& JsonNumber::operator[](int ind) {
|
||||
|
||||
bool JsonNumber::operator==(Value const& rhs) const {
|
||||
if (!IsA<JsonNumber>(&rhs)) { return false; }
|
||||
auto r_num = Cast<JsonNumber const>(&rhs)->GetNumber();
|
||||
if (std::isinf(number_)) {
|
||||
return std::isinf(Cast<JsonNumber const>(&rhs)->GetNumber());
|
||||
return std::isinf(r_num);
|
||||
}
|
||||
return std::abs(number_ - Cast<JsonNumber const>(&rhs)->GetNumber()) < kRtEps;
|
||||
if (std::isnan(number_)) {
|
||||
return std::isnan(r_num);
|
||||
}
|
||||
return number_ - r_num == 0;
|
||||
}
|
||||
|
||||
Value & JsonNumber::operator=(Value const &rhs) {
|
||||
|
||||
@@ -792,16 +792,17 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
auto j_begin = get<Integer const>(categories_segments[cnt]);
|
||||
auto j_end = get<Integer const>(categories_sizes[cnt]) + j_begin;
|
||||
bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
|
||||
CHECK_NE(j_end - j_begin, 0) << nidx;
|
||||
|
||||
for (auto j = j_begin; j < j_end; ++j) {
|
||||
auto const &category = get<Integer const>(categories[j]);
|
||||
auto cat = common::AsCat(category);
|
||||
max_cat = std::max(max_cat, cat);
|
||||
}
|
||||
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
|
||||
? 0
|
||||
: common::KCatBitField::ComputeStorageSize(max_cat);
|
||||
size = size == 0 ? 1 : size;
|
||||
// Have at least 1 category in split.
|
||||
CHECK_NE(std::numeric_limits<bst_cat_t>::min(), max_cat);
|
||||
size_t n_cats = max_cat + 1; // cat 0
|
||||
size_t size = common::KCatBitField::ComputeStorageSize(n_cats);
|
||||
std::vector<uint32_t> cat_bits_storage(size, 0);
|
||||
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
|
||||
for (auto j = j_begin; j < j_end; ++j) {
|
||||
|
||||
Reference in New Issue
Block a user