Unify the cat split storage for CPU. (#7937)

* Unify the cat split storage for CPU.

* Cleanup.

* Workaround.
Jiaming Yuan 2022-05-26 19:14:40 +08:00 committed by GitHub
parent 755d9d4609
commit 18cbebaeb9
5 changed files with 25 additions and 33 deletions


@@ -440,10 +440,10 @@ class RegTree : public Model {
   * \param right_sum The sum hess of right leaf.
   */
  void ExpandCategorical(bst_node_t nid, unsigned split_index,
-                        common::Span<uint32_t> split_cat, bool default_left,
+                        common::Span<const uint32_t> split_cat, bool default_left,
                         bst_float base_weight, bst_float left_leaf_weight,
-                        bst_float right_leaf_weight, bst_float loss_change,
-                        float sum_hess, float left_sum, float right_sum);
+                        bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
+                        float left_sum, float right_sum);
 
  bool HasCategoricalSplit() const {
    return !split_categories_.empty();

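Note: the `const` in `common::Span<const uint32_t>` makes the read-only contract explicit; the tree copies the category bits into its own storage and never mutates the caller's buffer. As a rough sketch of the encoding such a span carries (hypothetical helpers, not part of this patch; `common::CatBitField` is the real equivalent), category c sits at bit (c % 32) of word (c / 32):

    #include <cstdint>
    #include <vector>

    int main() {
      // One bit per category, packed into 32-bit words: room for categories 0..63.
      std::vector<uint32_t> storage((64 + 31) / 32, 0);
      auto set_cat = [&](uint32_t c) { storage[c / 32] |= 1u << (c % 32); };
      auto has_cat = [&](uint32_t c) { return (storage[c / 32] >> (c % 32)) & 1u; };
      set_cat(3);
      set_cat(42);
      // A Span<const uint32_t> over `storage` is all the tree needs to read.
      return has_cat(3) && has_cat(42) && !has_cat(7) ? 0 : 1;
    }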

@@ -57,9 +57,9 @@ class HistEvaluator {
  }
  /**
-   * \brief Use learned direction with one-hot split. Other implementations (LGB, sklearn)
-   *        create a pseudo-category for missing value but here we just do a complete scan
-   *        to avoid making specialized histogram bin.
+   * \brief Use learned direction with one-hot split. Other implementations (LGB) create a
+   *        pseudo-category for missing value but here we just do a complete scan to avoid
+   *        making specialized histogram bin.
   */
  void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
                       bst_feature_t fidx, bst_node_t nidx,
@@ -76,6 +76,7 @@ class HistEvaluator {
    GradStats right_sum;
    // best split so far
    SplitEntry best;
+    best.is_cat = false;  // marker for whether it's updated or not.
    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
    auto feature_sum = GradStats{
@@ -98,8 +99,8 @@ class HistEvaluator {
    }
    // missing on right (treat missing as chosen category)
-    left_sum.SetSubstract(left_sum, missing);
    right_sum.Add(missing);
+    left_sum.SetSubstract(parent.stats, right_sum);
    if (IsValid(left_sum, right_sum)) {
      auto missing_right_chg = static_cast<float>(
          evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
@@ -108,6 +109,13 @@ class HistEvaluator {
      }
    }
+    if (best.is_cat) {
+      auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
+      best.cat_bits.resize(n, 0);
+      common::CatBitField cat_bits{best.cat_bits};
+      cat_bits.Set(best.split_value);
+    }
+
    p_best->Update(best);
  }
@@ -345,25 +353,11 @@ class HistEvaluator {
        evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.right_sum});
    if (candidate.split.is_cat) {
-      std::vector<uint32_t> split_cats;
-      if (candidate.split.cat_bits.empty()) {
-        if (common::InvalidCat(candidate.split.split_value)) {
-          common::InvalidCategory();
-        }
-        auto cat = common::AsCat(candidate.split.split_value);
-        split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
-        LBitField32 cat_bits;
-        cat_bits = LBitField32(split_cats);
-        cat_bits.Set(cat);
-      } else {
-        split_cats = candidate.split.cat_bits;
-        common::CatBitField cat_bits{split_cats};
-      }
      tree.ExpandCategorical(
-          candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(),
-          base_weight, left_weight * param_.learning_rate, right_weight * param_.learning_rate,
-          candidate.split.loss_chg, parent_sum.GetHess(), candidate.split.left_sum.GetHess(),
-          candidate.split.right_sum.GetHess());
+          candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
+          candidate.split.DefaultLeft(), base_weight, left_weight * param_.learning_rate,
+          right_weight * param_.learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
+          candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
    } else {
      tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
                      candidate.split.DefaultLeft(), base_weight,

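Note: the two hunks above are the heart of the unification. `EnumerateOneHot` now materializes `cat_bits` for the winning one-hot split (previously only `split_value` recorded the chosen bin), and `left_sum` is recomputed as `parent.stats - right_sum`, so the two halves add up to the parent by construction. Because `cat_bits` is now always populated for categorical candidates, `ApplyTreeSplit` can drop its decode-or-copy branch and forward the bit field unchanged. A sketch of the new post-scan step, assuming the rounding convention the diff implies (`ComputeStorageSize` = bits rounded up to whole 32-bit words):

    #include <cstdint>
    #include <vector>

    // Stand-in for the CatBitField calls above: allocate ceil((n_bins + 1) / 32)
    // words, then set the single bit for the winning bin so downstream code sees
    // an ordinary category bit field rather than a special one-hot encoding.
    std::vector<uint32_t> MaterializeOneHot(uint32_t best_bin, uint32_t n_bins) {
      std::vector<uint32_t> bits((n_bins + 1 + 31) / 32, 0);
      bits[best_bin / 32] |= 1u << (best_bin % 32);
      return bits;
    }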

@@ -160,7 +160,7 @@ class TreeEvaluator {
      return;
    }
-    auto max_nidx = std::max(leftid, rightid);
+    size_t max_nidx = std::max(leftid, rightid);
    if (lower_bounds_.Size() <= max_nidx) {
      lower_bounds_.Resize(max_nidx * 2 + 1, -std::numeric_limits<float>::max());
    }

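Note: `lower_bounds_.Size()` returns an unsigned size while node ids are signed, so letting `auto` deduce `max_nidx` made `Size() <= max_nidx` a mixed signed/unsigned comparison; spelling it `size_t` keeps the comparison homogeneous. A minimal illustration of the pattern (hypothetical function, not from this patch):

    #include <algorithm>
    #include <cstddef>

    bool NeedsResize(std::size_t size, int leftid, int rightid) {
      // Node ids are non-negative, so widening to size_t is safe and avoids
      // the signed/unsigned mismatch that `auto` (int) would leave in place.
      std::size_t max_nidx = std::max(leftid, rightid);
      return size <= max_nidx;
    }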

@@ -808,11 +808,9 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
 }
 
 void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
-                                common::Span<uint32_t> split_cat, bool default_left,
-                                bst_float base_weight,
-                                bst_float left_leaf_weight,
-                                bst_float right_leaf_weight,
-                                bst_float loss_change, float sum_hess,
+                                common::Span<const uint32_t> split_cat, bool default_left,
+                                bst_float base_weight, bst_float left_leaf_weight,
+                                bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
                                 float left_sum, float right_sum) {
  this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(),
                   default_left, base_weight,

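Note: as the unchanged context line shows, a categorical node reuses the numeric-split plumbing by storing `quiet_NaN()` as its split condition; the actual predicate lives in the bits copied from `split_cat`. A one-line sketch of why NaN works as a placeholder (hypothetical helper):

    #include <cmath>

    // No real-valued threshold is ever NaN, so a NaN condition can only mean
    // the node's test is defined elsewhere (here: the category bit field).
    bool HasNumericThreshold(float split_cond) { return !std::isnan(split_cond); }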

@@ -10,10 +10,10 @@ exact_parameter_strategy = strategies.fixed_dictionaries({
    'nthread': strategies.integers(1, 4),
    'max_depth': strategies.integers(1, 11),
    'min_child_weight': strategies.floats(0.5, 2.0),
-    'alpha': strategies.floats(0.0, 2.0),
+    'alpha': strategies.floats(1e-5, 2.0),
    'lambda': strategies.floats(1e-5, 2.0),
    'eta': strategies.floats(0.01, 0.5),
-    'gamma': strategies.floats(0.0, 2.0),
+    'gamma': strategies.floats(1e-5, 2.0),
    'seed': strategies.integers(0, 10),
    # We cannot enable subsampling as the training loss can increase
    # 'subsample': strategies.floats(0.5, 1.0),
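Note: the test-side change mirrors the existing `lambda` bound by moving `alpha` and `gamma` off exact zero; this is presumably the "Workaround" from the commit message, steering hypothesis away from the degenerate zero-regularization corner of the parameter space.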