Unify the cat split storage for CPU. (#7937)
* Unify the cat split storage for CPU. * Cleanup. * Workaround.
This commit is contained in:
@@ -57,9 +57,9 @@ class HistEvaluator {
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Use learned direction with one-hot split. Other implementations (LGB, sklearn)
|
||||
* create a pseudo-category for missing values, but here we just do a complete scan
|
||||
* to avoid making a specialized histogram bin.
|
||||
* \brief Use learned direction with one-hot split. Other implementations (LGB) create a
|
||||
* pseudo-category for missing values, but here we just do a complete scan to avoid
|
||||
* making a specialized histogram bin.
|
||||
*/
|
||||
void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
|
||||
bst_feature_t fidx, bst_node_t nidx,
|
||||
@@ -76,6 +76,7 @@ class HistEvaluator {
|
||||
GradStats right_sum;
|
||||
// best split so far
|
||||
SplitEntry best;
|
||||
best.is_cat = false; // marker for whether it's updated or not.
|
||||
|
||||
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
|
||||
auto feature_sum = GradStats{
|
||||
@@ -98,8 +99,8 @@ class HistEvaluator {
|
||||
}
|
||||
|
||||
// missing on right (treat missing as chosen category)
|
||||
left_sum.SetSubstract(left_sum, missing);
|
||||
right_sum.Add(missing);
|
||||
left_sum.SetSubstract(parent.stats, right_sum);
|
||||
if (IsValid(left_sum, right_sum)) {
|
||||
auto missing_right_chg = static_cast<float>(
|
||||
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
|
||||
@@ -108,6 +109,13 @@ class HistEvaluator {
|
||||
}
|
||||
}
|
||||
|
||||
if (best.is_cat) {
|
||||
auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
|
||||
best.cat_bits.resize(n, 0);
|
||||
common::CatBitField cat_bits{best.cat_bits};
|
||||
cat_bits.Set(best.split_value);
|
||||
}
|
||||
|
||||
p_best->Update(best);
|
||||
}
|
||||
|
||||
@@ -345,25 +353,11 @@ class HistEvaluator {
|
||||
evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.right_sum});
|
||||
|
||||
if (candidate.split.is_cat) {
|
||||
std::vector<uint32_t> split_cats;
|
||||
if (candidate.split.cat_bits.empty()) {
|
||||
if (common::InvalidCat(candidate.split.split_value)) {
|
||||
common::InvalidCategory();
|
||||
}
|
||||
auto cat = common::AsCat(candidate.split.split_value);
|
||||
split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
|
||||
LBitField32 cat_bits;
|
||||
cat_bits = LBitField32(split_cats);
|
||||
cat_bits.Set(cat);
|
||||
} else {
|
||||
split_cats = candidate.split.cat_bits;
|
||||
common::CatBitField cat_bits{split_cats};
|
||||
}
|
||||
tree.ExpandCategorical(
|
||||
candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(),
|
||||
base_weight, left_weight * param_.learning_rate, right_weight * param_.learning_rate,
|
||||
candidate.split.loss_chg, parent_sum.GetHess(), candidate.split.left_sum.GetHess(),
|
||||
candidate.split.right_sum.GetHess());
|
||||
candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
|
||||
candidate.split.DefaultLeft(), base_weight, left_weight * param_.learning_rate,
|
||||
right_weight * param_.learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
|
||||
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
|
||||
} else {
|
||||
tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
|
||||
candidate.split.DefaultLeft(), base_weight,
|
||||
|
||||
@@ -160,7 +160,7 @@ class TreeEvaluator {
|
||||
return;
|
||||
}
|
||||
|
||||
auto max_nidx = std::max(leftid, rightid);
|
||||
size_t max_nidx = std::max(leftid, rightid);
|
||||
if (lower_bounds_.Size() <= max_nidx) {
|
||||
lower_bounds_.Resize(max_nidx * 2 + 1, -std::numeric_limits<float>::max());
|
||||
}
|
||||
|
||||
@@ -808,11 +808,9 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
|
||||
}
|
||||
|
||||
void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
|
||||
common::Span<uint32_t> split_cat, bool default_left,
|
||||
bst_float base_weight,
|
||||
bst_float left_leaf_weight,
|
||||
bst_float right_leaf_weight,
|
||||
bst_float loss_change, float sum_hess,
|
||||
common::Span<const uint32_t> split_cat, bool default_left,
|
||||
bst_float base_weight, bst_float left_leaf_weight,
|
||||
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
|
||||
float left_sum, float right_sum) {
|
||||
this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(),
|
||||
default_left, base_weight,
|
||||
|
||||
Reference in New Issue
Block a user