Implement max_cat_threshold for CPU. (#7957)
This commit is contained in:
@@ -144,7 +144,8 @@ class HistEvaluator {
|
||||
|
||||
auto const &cut_ptr = cut.Ptrs();
|
||||
auto const &parent = snode_[nidx];
|
||||
bst_bin_t n_bins{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
|
||||
bst_bin_t n_bins_feature{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
|
||||
auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
|
||||
|
||||
// statistics on both sides of split
|
||||
GradStats left_sum;
|
||||
@@ -152,7 +153,7 @@ class HistEvaluator {
|
||||
// best split so far
|
||||
SplitEntry best;
|
||||
|
||||
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
|
||||
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins_feature);
|
||||
bst_bin_t ibegin, iend;
|
||||
bst_bin_t f_begin = cut_ptr[fidx];
|
||||
if (d_step > 0) {
|
||||
@@ -160,7 +161,7 @@ class HistEvaluator {
|
||||
iend = ibegin + n_bins - 1;
|
||||
} else {
|
||||
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
|
||||
iend = f_begin;
|
||||
iend = ibegin - n_bins + 1;
|
||||
}
|
||||
|
||||
bst_bin_t best_thresh{-1};
|
||||
@@ -177,7 +178,7 @@ class HistEvaluator {
|
||||
auto loss_chg =
|
||||
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
|
||||
parent.root_gain;
|
||||
// We don't have a numeric split point, nan hare is a dummy split.
|
||||
// We don't have a numeric split point, nan here is a dummy split.
|
||||
if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
|
||||
left_sum, right_sum)) {
|
||||
best_thresh = i;
|
||||
@@ -186,10 +187,11 @@ class HistEvaluator {
|
||||
}
|
||||
|
||||
if (best_thresh != -1) {
|
||||
auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
|
||||
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature + 1);
|
||||
best.cat_bits = decltype(best.cat_bits)(n, 0);
|
||||
common::CatBitField cat_bits{best.cat_bits};
|
||||
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : best_thresh - iend;
|
||||
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : (best_thresh - f_begin);
|
||||
CHECK_GT(partition, 0);
|
||||
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
|
||||
[&](size_t c) { cat_bits.Set(c); });
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user