Add max_cat_threshold to GPU and handle missing cat values. (#8212)
This commit is contained in:
@@ -143,8 +143,12 @@ class HistEvaluator {
|
||||
static_assert(d_step == +1 || d_step == -1, "Invalid step.");
|
||||
|
||||
auto const &cut_ptr = cut.Ptrs();
|
||||
auto const &cut_val = cut.Values();
|
||||
auto const &parent = snode_[nidx];
|
||||
bst_bin_t n_bins_feature{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
|
||||
|
||||
bst_bin_t f_begin = cut_ptr[fidx];
|
||||
bst_bin_t f_end = cut_ptr[fidx + 1];
|
||||
bst_bin_t n_bins_feature{f_end - f_begin};
|
||||
auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
|
||||
|
||||
// statistics on both sides of split
|
||||
@@ -153,19 +157,18 @@ class HistEvaluator {
|
||||
// best split so far
|
||||
SplitEntry best;
|
||||
|
||||
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins_feature);
|
||||
bst_bin_t ibegin, iend;
|
||||
bst_bin_t f_begin = cut_ptr[fidx];
|
||||
auto f_hist = hist.subspan(f_begin, n_bins_feature);
|
||||
bst_bin_t it_begin, it_end;
|
||||
if (d_step > 0) {
|
||||
ibegin = f_begin;
|
||||
iend = ibegin + n_bins - 1;
|
||||
it_begin = f_begin;
|
||||
it_end = it_begin + n_bins - 1;
|
||||
} else {
|
||||
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
|
||||
iend = ibegin - n_bins + 1;
|
||||
it_begin = f_end - 1;
|
||||
it_end = it_begin - n_bins + 1;
|
||||
}
|
||||
|
||||
bst_bin_t best_thresh{-1};
|
||||
for (bst_bin_t i = ibegin; i != iend; i += d_step) {
|
||||
for (bst_bin_t i = it_begin; i != it_end; i += d_step) {
|
||||
auto j = i - f_begin; // index local to current feature
|
||||
if (d_step == 1) {
|
||||
right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
|
||||
@@ -187,13 +190,15 @@ class HistEvaluator {
|
||||
}
|
||||
|
||||
if (best_thresh != -1) {
|
||||
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature + 1);
|
||||
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
|
||||
best.cat_bits = decltype(best.cat_bits)(n, 0);
|
||||
common::CatBitField cat_bits{best.cat_bits};
|
||||
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : (best_thresh - f_begin);
|
||||
bst_bin_t partition = d_step == 1 ? (best_thresh - it_begin + 1) : (best_thresh - f_begin);
|
||||
CHECK_GT(partition, 0);
|
||||
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
|
||||
[&](size_t c) { cat_bits.Set(c); });
|
||||
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, [&](size_t c) {
|
||||
auto cat = cut_val[c + f_begin];
|
||||
cat_bits.Set(cat);
|
||||
});
|
||||
}
|
||||
|
||||
p_best->Update(best);
|
||||
|
||||
Reference in New Issue
Block a user