Unify the cat split storage for CPU. (#7937)

* Unify the cat split storage for CPU.
* Cleanup.
* Workaround.

parent 755d9d4609
commit 18cbebaeb9
@@ -440,10 +440,10 @@ class RegTree : public Model {
    * \param right_sum The sum hess of right leaf.
    */
   void ExpandCategorical(bst_node_t nid, unsigned split_index,
-                         common::Span<uint32_t> split_cat, bool default_left,
+                         common::Span<const uint32_t> split_cat, bool default_left,
                          bst_float base_weight, bst_float left_leaf_weight,
-                         bst_float right_leaf_weight, bst_float loss_change,
-                         float sum_hess, float left_sum, float right_sum);
+                         bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
+                         float left_sum, float right_sum);
 
   bool HasCategoricalSplit() const {
     return !split_categories_.empty();
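A note on the signature change above: taking common::Span<const uint32_t> instead of common::Span<uint32_t> means the tree only receives a read-only view of the caller's category bit words. A quick standalone sketch of why the const element type helps callers: both mutable and const storage convert to the read-only view without copying. This uses std::span from C++20 as a stand-in for common::Span and a hypothetical CountWords consumer; it is not XGBoost code.

// Sketch only: std::span<const uint32_t> stands in for common::Span<const uint32_t>.
#include <cstdint>
#include <iostream>
#include <span>
#include <vector>

// Hypothetical consumer mirroring the new ExpandCategorical parameter:
// it can read the category bit words but cannot mutate the caller's storage.
std::size_t CountWords(std::span<const std::uint32_t> split_cat) {
  return split_cat.size();
}

int main() {
  std::vector<std::uint32_t> owned_bits{0b1010u};         // mutable storage
  const std::vector<std::uint32_t> frozen_bits{0b0101u};  // const storage
  // Both convert implicitly to the read-only view; no copy is made.
  std::cout << CountWords(owned_bits) << " " << CountWords(frozen_bits) << "\n";
  return 0;
}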
@@ -57,9 +57,9 @@ class HistEvaluator {
   }
 
   /**
-   * \brief Use learned direction with one-hot split. Other implementations (LGB, sklearn)
-   * create a pseudo-category for missing value but here we just do a complete scan
-   * to avoid making specialized histogram bin.
+   * \brief Use learned direction with one-hot split. Other implementations (LGB) create a
+   * pseudo-category for missing value but here we just do a complete scan to avoid
+   * making specialized histogram bin.
    */
   void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
                        bst_feature_t fidx, bst_node_t nidx,
@@ -76,6 +76,7 @@ class HistEvaluator {
     GradStats right_sum;
     // best split so far
     SplitEntry best;
+    best.is_cat = false;  // marker for whether it's updated or not.
 
     auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
     auto feature_sum = GradStats{
@@ -98,8 +99,8 @@ class HistEvaluator {
       }
 
       // missing on right (treat missing as chosen category)
-      left_sum.SetSubstract(left_sum, missing);
       right_sum.Add(missing);
+      left_sum.SetSubstract(parent.stats, right_sum);
       if (IsValid(left_sum, right_sum)) {
         auto missing_right_chg = static_cast<float>(
             evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
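The hunk above changes how the "missing goes right" candidate is formed during the one-hot scan: the missing statistics are added to the right side (the chosen category), and the left side is then recovered as parent minus right, rather than subtracting missing from a running left sum. Below is a small self-contained sketch of that bookkeeping, assuming GradStats is essentially a (sum_grad, sum_hess) pair with the Add/SetSubstract semantics used in the diff; Stats here is a stand-in, not the real class.

// Sketch of the gradient bookkeeping in the one-hot scan; illustrative only.
#include <iostream>

struct Stats {
  double grad{0}, hess{0};
  void Add(Stats const &o) { grad += o.grad; hess += o.hess; }
  // Equivalent of the SetSubstract call in the diff: *this = a - b.
  void SetSubstract(Stats const &a, Stats const &b) {
    grad = a.grad - b.grad;
    hess = a.hess - b.hess;
  }
};

int main() {
  Stats parent{10.0, 8.0};   // all rows reaching the node
  Stats category{3.0, 2.0};  // rows whose feature equals the candidate category
  Stats missing{1.0, 0.5};   // rows with a missing value for this feature

  // Missing values are treated as the chosen category: they go to the right.
  Stats right = category;
  right.Add(missing);
  // The left side is everything else, recovered from the parent in one step.
  Stats left;
  left.SetSubstract(parent, right);
  std::cout << left.grad << " " << left.hess << "\n";  // 6 and 5.5
  return 0;
}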
@@ -108,6 +109,13 @@ class HistEvaluator {
       }
     }
 
+    if (best.is_cat) {
+      auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
+      best.cat_bits.resize(n, 0);
+      common::CatBitField cat_bits{best.cat_bits};
+      cat_bits.Set(best.split_value);
+    }
+
     p_best->Update(best);
   }
 
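The newly added block stores the winning one-hot category directly in best.cat_bits through common::CatBitField, whose ComputeStorageSize and Set calls appear in the diff. Below is an illustrative, standalone sketch of the kind of bit-packing involved, assuming 32-bit storage words to match the uint32_t vectors used elsewhere in this commit; it is not the actual CatBitField implementation (which may differ, for example in bit ordering).

// Standalone sketch of category bit-field bookkeeping; illustrative only.
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::size_t kWordBits = 32;

// Number of 32-bit words needed to hold n_cats category flags.
std::size_t ComputeStorageSize(std::size_t n_cats) {
  return (n_cats + kWordBits - 1) / kWordBits;
}

// Mark category cat as belonging to one side of the split.
void Set(std::vector<std::uint32_t> *bits, std::size_t cat) {
  (*bits)[cat / kWordBits] |= (1u << (cat % kWordBits));
}

bool Check(std::vector<std::uint32_t> const &bits, std::size_t cat) {
  return (bits[cat / kWordBits] >> (cat % kWordBits)) & 1u;
}

int main() {
  std::size_t n_bins = 40;  // e.g. 40 category bins for this feature
  std::vector<std::uint32_t> cat_bits(ComputeStorageSize(n_bins + 1), 0);
  Set(&cat_bits, 37);       // record the winning category
  return Check(cat_bits, 37) ? 0 : 1;
}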
@@ -345,25 +353,11 @@ class HistEvaluator {
         evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.right_sum});
 
     if (candidate.split.is_cat) {
-      std::vector<uint32_t> split_cats;
-      if (candidate.split.cat_bits.empty()) {
-        if (common::InvalidCat(candidate.split.split_value)) {
-          common::InvalidCategory();
-        }
-        auto cat = common::AsCat(candidate.split.split_value);
-        split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
-        LBitField32 cat_bits;
-        cat_bits = LBitField32(split_cats);
-        cat_bits.Set(cat);
-      } else {
-        split_cats = candidate.split.cat_bits;
-        common::CatBitField cat_bits{split_cats};
-      }
       tree.ExpandCategorical(
-          candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(),
-          base_weight, left_weight * param_.learning_rate, right_weight * param_.learning_rate,
-          candidate.split.loss_chg, parent_sum.GetHess(), candidate.split.left_sum.GetHess(),
-          candidate.split.right_sum.GetHess());
+          candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
+          candidate.split.DefaultLeft(), base_weight, left_weight * param_.learning_rate,
+          right_weight * param_.learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
+          candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
     } else {
       tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
                       candidate.split.DefaultLeft(), base_weight,
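This hunk is the core of the unification: the builder no longer reconstructs a local split_cats bit field from split_value when candidate.split.cat_bits happens to be empty. The evaluator always fills cat_bits (see the block added earlier), and it is handed straight to RegTree::ExpandCategorical as a read-only span. A minimal sketch of the resulting single-owner handoff, using hypothetical stand-in types (SplitEntryLike, TreeLike) rather than the real SplitEntry and RegTree:

// Sketch of the unified data flow; only the ownership/handoff pattern matters.
#include <cstdint>
#include <span>
#include <vector>

struct SplitEntryLike {
  bool is_cat{false};
  std::vector<std::uint32_t> cat_bits;  // single owner of the category bits
};

struct TreeLike {
  // Mirrors the new ExpandCategorical parameter: read-only view, no copy.
  void ExpandCategorical(std::span<const std::uint32_t> split_cat) {
    stored_.assign(split_cat.begin(), split_cat.end());
  }
  std::vector<std::uint32_t> stored_;
};

int main() {
  SplitEntryLike split;
  split.is_cat = true;
  split.cat_bits = {0b100u};  // filled once, by the evaluator

  TreeLike tree;
  if (split.is_cat) {
    // No rebuilding from split_value any more: just forward the bits.
    tree.ExpandCategorical(split.cat_bits);
  }
  return tree.stored_.size() == 1 ? 0 : 1;
}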
@@ -160,7 +160,7 @@ class TreeEvaluator {
       return;
     }
 
-    auto max_nidx = std::max(leftid, rightid);
+    size_t max_nidx = std::max(leftid, rightid);
     if (lower_bounds_.Size() <= max_nidx) {
       lower_bounds_.Resize(max_nidx * 2 + 1, -std::numeric_limits<float>::max());
     }
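The auto to size_t change above is small; my reading (an assumption, not stated in the commit) is that it avoids deducing a signed node-id type and then comparing it against the unsigned result of lower_bounds_.Size(). A minimal standalone illustration, with std::vector standing in for the bounds storage:

// Illustration only, not the TreeEvaluator code; node ids are assumed to be a
// signed 32-bit type here.
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

int main() {
  std::int32_t leftid = 3, rightid = 4;
  std::vector<float> lower_bounds(2);

  // With auto, max_nidx would be int32_t and the comparison below would mix
  // signed and unsigned operands (-Wsign-compare). Spelling out size_t keeps
  // the comparison and the resize arithmetic unsigned.
  std::size_t max_nidx = std::max(leftid, rightid);
  if (lower_bounds.size() <= max_nidx) {
    lower_bounds.resize(max_nidx * 2 + 1, -std::numeric_limits<float>::max());
  }
  return static_cast<int>(lower_bounds.size());  // 9
}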
@@ -808,11 +808,9 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
 }
 
 void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
-                                common::Span<uint32_t> split_cat, bool default_left,
-                                bst_float base_weight,
-                                bst_float left_leaf_weight,
-                                bst_float right_leaf_weight,
-                                bst_float loss_change, float sum_hess,
+                                common::Span<const uint32_t> split_cat, bool default_left,
+                                bst_float base_weight, bst_float left_leaf_weight,
+                                bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
                                 float left_sum, float right_sum) {
   this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(),
                    default_left, base_weight,
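With the parameter list reflowed above, it is easier to see that ExpandCategorical delegates to ExpandNode with quiet_NaN() as the numeric split condition. Presumably (my reading, not stated in the commit) NaN serves as a placeholder that cannot be mistaken for a real threshold, since every ordered comparison against NaN is false. A standalone illustration of that IEEE-754 fact, not of XGBoost's traversal logic:

// Every ordered comparison against NaN evaluates to false.
#include <cmath>
#include <iostream>
#include <limits>

int main() {
  float split_cond = std::numeric_limits<float>::quiet_NaN();
  float fvalue = 0.5f;
  std::cout << std::boolalpha
            << (fvalue < split_cond) << " "     // false
            << (fvalue >= split_cond) << " "    // false
            << std::isnan(split_cond) << "\n";  // true
  return 0;
}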
@@ -10,10 +10,10 @@ exact_parameter_strategy = strategies.fixed_dictionaries({
     'nthread': strategies.integers(1, 4),
     'max_depth': strategies.integers(1, 11),
     'min_child_weight': strategies.floats(0.5, 2.0),
-    'alpha': strategies.floats(0.0, 2.0),
+    'alpha': strategies.floats(1e-5, 2.0),
     'lambda': strategies.floats(1e-5, 2.0),
     'eta': strategies.floats(0.01, 0.5),
-    'gamma': strategies.floats(0.0, 2.0),
+    'gamma': strategies.floats(1e-5, 2.0),
     'seed': strategies.integers(0, 10),
     # We cannot enable subsampling as the training loss can increase
     # 'subsample': strategies.floats(0.5, 1.0),