Add max_cat_threshold to GPU and handle missing cat values. (#8212)

This commit is contained in:
Jiaming Yuan
2022-09-07 00:57:51 +08:00
committed by GitHub
parent 441ffc017a
commit b5eb36f1af
10 changed files with 546 additions and 122 deletions

View File

@@ -601,13 +601,14 @@ struct GPUHistMakerDevice {
auto is_cat = candidate.split.is_cat;
if (is_cat) {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large.";
std::vector<uint32_t> split_cats;
// should be set to nan in evaluation split.
CHECK(common::CheckNAN(candidate.split.fvalue));
std::vector<common::CatBitField::value_type> split_cats;
CHECK_GT(candidate.split.split_cats.Bits().size(), 0);
auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
auto max_cat = candidate.split.MaxCat();
split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
auto n_bins_feature = page->Cuts().FeatureBins(candidate.split.findex);
split_cats.resize(common::CatBitField::ComputeStorageSize(n_bins_feature), 0);
CHECK_LE(split_cats.size(), h_cats.size());
std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
@@ -616,6 +617,7 @@ struct GPUHistMakerDevice {
base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
} else {
CHECK(!common::CheckNAN(candidate.split.fvalue));
tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.GetHess(),