More categorical tests and disable shap sparse test. (#6219)

* Fix tree load with 32 category.
This commit is contained in:
Jiaming Yuan
2020-10-10 16:12:37 +08:00
committed by GitHub
parent c991eb612d
commit b5b24354b8
5 changed files with 120 additions and 27 deletions

View File

@@ -186,7 +186,9 @@ Json& JsonObject::operator[](int ind) {
}
bool JsonObject::operator==(Value const& rhs) const {
if (!IsA<JsonObject>(&rhs)) { return false; }
if (!IsA<JsonObject>(&rhs)) {
return false;
}
return object_ == Cast<JsonObject const>(&rhs)->GetObject();
}
@@ -275,10 +277,14 @@ Json& JsonNumber::operator[](int ind) {
bool JsonNumber::operator==(Value const& rhs) const {
if (!IsA<JsonNumber>(&rhs)) { return false; }
auto r_num = Cast<JsonNumber const>(&rhs)->GetNumber();
if (std::isinf(number_)) {
return std::isinf(Cast<JsonNumber const>(&rhs)->GetNumber());
return std::isinf(r_num);
}
return std::abs(number_ - Cast<JsonNumber const>(&rhs)->GetNumber()) < kRtEps;
if (std::isnan(number_)) {
return std::isnan(r_num);
}
return number_ - r_num == 0;
}
Value & JsonNumber::operator=(Value const &rhs) {

View File

@@ -792,16 +792,17 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
auto j_begin = get<Integer const>(categories_segments[cnt]);
auto j_end = get<Integer const>(categories_sizes[cnt]) + j_begin;
bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
CHECK_NE(j_end - j_begin, 0) << nidx;
for (auto j = j_begin; j < j_end; ++j) {
auto const &category = get<Integer const>(categories[j]);
auto cat = common::AsCat(category);
max_cat = std::max(max_cat, cat);
}
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
? 0
: common::KCatBitField::ComputeStorageSize(max_cat);
size = size == 0 ? 1 : size;
// Have at least 1 category in split.
CHECK_NE(std::numeric_limits<bst_cat_t>::min(), max_cat);
size_t n_cats = max_cat + 1; // cat 0
size_t size = common::KCatBitField::ComputeStorageSize(n_cats);
std::vector<uint32_t> cat_bits_storage(size, 0);
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
for (auto j = j_begin; j < j_end; ++j) {