Implement max_cat_threshold for CPU. (#7957)
This commit is contained in:
@@ -54,7 +54,7 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
|
||||
*/
|
||||
template <bool validate = true>
|
||||
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
|
||||
CLBitField32 const s_cats(cats);
|
||||
KCatBitField const s_cats(cats);
|
||||
// FIXME: Size() is not accurate since it represents the size of bit set instead of
|
||||
// actual number of categories.
|
||||
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
|
||||
|
||||
@@ -144,7 +144,8 @@ class HistEvaluator {
|
||||
|
||||
auto const &cut_ptr = cut.Ptrs();
|
||||
auto const &parent = snode_[nidx];
|
||||
bst_bin_t n_bins{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
|
||||
bst_bin_t n_bins_feature{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
|
||||
auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
|
||||
|
||||
// statistics on both sides of split
|
||||
GradStats left_sum;
|
||||
@@ -152,7 +153,7 @@ class HistEvaluator {
|
||||
// best split so far
|
||||
SplitEntry best;
|
||||
|
||||
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
|
||||
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins_feature);
|
||||
bst_bin_t ibegin, iend;
|
||||
bst_bin_t f_begin = cut_ptr[fidx];
|
||||
if (d_step > 0) {
|
||||
@@ -160,7 +161,7 @@ class HistEvaluator {
|
||||
iend = ibegin + n_bins - 1;
|
||||
} else {
|
||||
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
|
||||
iend = f_begin;
|
||||
iend = ibegin - n_bins + 1;
|
||||
}
|
||||
|
||||
bst_bin_t best_thresh{-1};
|
||||
@@ -177,7 +178,7 @@ class HistEvaluator {
|
||||
auto loss_chg =
|
||||
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
|
||||
parent.root_gain;
|
||||
// We don't have a numeric split point, nan hare is a dummy split.
|
||||
// We don't have a numeric split point, nan here is a dummy split.
|
||||
if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
|
||||
left_sum, right_sum)) {
|
||||
best_thresh = i;
|
||||
@@ -186,10 +187,11 @@ class HistEvaluator {
|
||||
}
|
||||
|
||||
if (best_thresh != -1) {
|
||||
auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
|
||||
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature + 1);
|
||||
best.cat_bits = decltype(best.cat_bits)(n, 0);
|
||||
common::CatBitField cat_bits{best.cat_bits};
|
||||
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : best_thresh - iend;
|
||||
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : (best_thresh - f_begin);
|
||||
CHECK_GT(partition, 0);
|
||||
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
|
||||
[&](size_t c) { cat_bits.Set(c); });
|
||||
}
|
||||
|
||||
@@ -40,6 +40,8 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
|
||||
|
||||
uint32_t max_cat_to_onehot{4};
|
||||
|
||||
bst_bin_t max_cat_threshold{64};
|
||||
|
||||
//----- the rest parameters are less important ----
|
||||
// minimum amount of hessian(weight) allowed in a child
|
||||
float min_child_weight;
|
||||
@@ -113,6 +115,12 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
|
||||
.set_default(4)
|
||||
.set_lower_bound(1)
|
||||
.describe("Maximum number of categories to use one-hot encoding based split.");
|
||||
// Registers the `max_cat_threshold` parameter: default 64, lower bound 1.
// The two adjacent string literals below are concatenated by the compiler, so the
// first one must end with a space to keep "partition-based" and "splits." apart.
DMLC_DECLARE_FIELD(max_cat_threshold)
    .set_default(64)
    .set_lower_bound(1)
    .describe(
        "Maximum number of categories considered for split. Used only by partition-based "
        "splits.");
|
||||
DMLC_DECLARE_FIELD(min_child_weight)
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(1.0f)
|
||||
|
||||
Reference in New Issue
Block a user