Implement max_cat_threshold for CPU. (#7957)

This commit is contained in:
Jiaming Yuan
2022-06-04 11:02:46 +08:00
committed by GitHub
parent 78694405a6
commit b90c6d25e8
8 changed files with 177 additions and 20 deletions

View File

@@ -54,7 +54,7 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
*/
template <bool validate = true>
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
CLBitField32 const s_cats(cats);
KCatBitField const s_cats(cats);
// FIXME: Size() is not accurate since it represents the size of bit set instead of
// actual number of categories.
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {

View File

@@ -144,7 +144,8 @@ class HistEvaluator {
auto const &cut_ptr = cut.Ptrs();
auto const &parent = snode_[nidx];
bst_bin_t n_bins{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
bst_bin_t n_bins_feature{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
// statistics on both sides of split
GradStats left_sum;
@@ -152,7 +153,7 @@ class HistEvaluator {
// best split so far
SplitEntry best;
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins_feature);
bst_bin_t ibegin, iend;
bst_bin_t f_begin = cut_ptr[fidx];
if (d_step > 0) {
@@ -160,7 +161,7 @@ class HistEvaluator {
iend = ibegin + n_bins - 1;
} else {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = f_begin;
iend = ibegin - n_bins + 1;
}
bst_bin_t best_thresh{-1};
@@ -177,7 +178,7 @@ class HistEvaluator {
auto loss_chg =
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
parent.root_gain;
// We don't have a numeric split point, nan hare is a dummy split.
// We don't have a numeric split point, nan here is a dummy split.
if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
left_sum, right_sum)) {
best_thresh = i;
@@ -186,10 +187,11 @@ class HistEvaluator {
}
if (best_thresh != -1) {
auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature + 1);
best.cat_bits = decltype(best.cat_bits)(n, 0);
common::CatBitField cat_bits{best.cat_bits};
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : best_thresh - iend;
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : (best_thresh - f_begin);
CHECK_GT(partition, 0);
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
[&](size_t c) { cat_bits.Set(c); });
}

View File

@@ -40,6 +40,8 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
uint32_t max_cat_to_onehot{4};
bst_bin_t max_cat_threshold{64};
//----- the rest parameters are less important ----
// minimum amount of hessian(weight) allowed in a child
float min_child_weight;
@@ -113,6 +115,12 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
.set_default(4)
.set_lower_bound(1)
.describe("Maximum number of categories to use one-hot encoding based split.");
DMLC_DECLARE_FIELD(max_cat_threshold)
.set_default(64)
.set_lower_bound(1)
.describe(
"Maximum number of categories considered for split. Used only by partition-based"
"splits.");
DMLC_DECLARE_FIELD(min_child_weight)
.set_lower_bound(0.0f)
.set_default(1.0f)