Add max_cat_threshold to GPU and handle missing cat values. (#8212)

This commit is contained in:
Jiaming Yuan
2022-09-07 00:57:51 +08:00
committed by GitHub
parent 441ffc017a
commit b5eb36f1af
10 changed files with 546 additions and 122 deletions

View File

@@ -143,8 +143,12 @@ class HistEvaluator {
static_assert(d_step == +1 || d_step == -1, "Invalid step.");
auto const &cut_ptr = cut.Ptrs();
auto const &cut_val = cut.Values();
auto const &parent = snode_[nidx];
bst_bin_t n_bins_feature{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
bst_bin_t f_begin = cut_ptr[fidx];
bst_bin_t f_end = cut_ptr[fidx + 1];
bst_bin_t n_bins_feature{f_end - f_begin};
auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
// statistics on both sides of split
@@ -153,19 +157,18 @@ class HistEvaluator {
// best split so far
SplitEntry best;
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins_feature);
bst_bin_t ibegin, iend;
bst_bin_t f_begin = cut_ptr[fidx];
auto f_hist = hist.subspan(f_begin, n_bins_feature);
bst_bin_t it_begin, it_end;
if (d_step > 0) {
ibegin = f_begin;
iend = ibegin + n_bins - 1;
it_begin = f_begin;
it_end = it_begin + n_bins - 1;
} else {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = ibegin - n_bins + 1;
it_begin = f_end - 1;
it_end = it_begin - n_bins + 1;
}
bst_bin_t best_thresh{-1};
for (bst_bin_t i = ibegin; i != iend; i += d_step) {
for (bst_bin_t i = it_begin; i != it_end; i += d_step) {
auto j = i - f_begin; // index local to current feature
if (d_step == 1) {
right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
@@ -187,13 +190,15 @@ class HistEvaluator {
}
if (best_thresh != -1) {
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature + 1);
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
best.cat_bits = decltype(best.cat_bits)(n, 0);
common::CatBitField cat_bits{best.cat_bits};
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : (best_thresh - f_begin);
bst_bin_t partition = d_step == 1 ? (best_thresh - it_begin + 1) : (best_thresh - f_begin);
CHECK_GT(partition, 0);
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
[&](size_t c) { cat_bits.Set(c); });
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, [&](size_t c) {
auto cat = cut_val[c + f_begin];
cat_bits.Set(cat);
});
}
p_best->Update(best);