Handle missing values in one hot splits. (#7934)

This commit is contained in:
Jiaming Yuan
2022-05-24 20:48:41 +08:00
committed by GitHub
parent 18a38f7ca0
commit 606be9e663
3 changed files with 105 additions and 14 deletions

View File

@@ -45,14 +45,72 @@ class HistEvaluator {
// then - there are no missing values
// else - there are missing values
bool static SplitContainsMissingValues(const GradStats e, const NodeEntry &snode) {
  // `e` is the gradient/hessian sum accumulated over one feature's histogram.
  // When it matches the node's total statistics exactly, every row contributed
  // a value for the feature, i.e. nothing is missing.  Exact floating point
  // equality is intentional here and is kept as-is.
  bool const covers_whole_node =
      e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess();
  return !covers_whole_node;
}
// A candidate split is accepted only when the hessian sum of each child
// reaches the configured `min_child_weight`.
bool IsValid(GradStats const &left, GradStats const &right) const {
  auto const &min_weight = param_.min_child_weight;
  bool const left_ok = left.GetHess() >= min_weight;
  bool const right_ok = right.GetHess() >= min_weight;
  return left_ok && right_ok;
}
/**
 * \brief Evaluate one-hot splits (a single category against all the others) for
 *        feature `fidx` on node `nidx`, learning the direction for missing values.
 *
 * Other implementations (LGB, sklearn) create a pseudo-category for missing values.
 * Here every bin of the feature is scanned instead, and for each bin the missing
 * statistics are tried on both sides, so no specialized histogram bin is needed.
 *
 * \param cut       Histogram cuts providing the bin range and split values of the feature.
 * \param hist      Gradient histogram of the node.
 * \param fidx      Feature being evaluated.
 * \param nidx      Node being evaluated.
 * \param evaluator Used to compute the gain of each candidate split.
 * \param p_best    Updated with the best candidate found by this scan.
 */
void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
                     bst_feature_t fidx, bst_node_t nidx,
                     TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                     SplitEntry *p_best) const {
  std::vector<uint32_t> const &ptrs = cut.Ptrs();
  std::vector<bst_float> const &values = cut.Values();

  auto const first_bin = static_cast<bst_bin_t>(ptrs[fidx]);
  auto const last_bin = static_cast<bst_bin_t>(ptrs[fidx + 1]);
  bst_bin_t const n_bins = last_bin - first_bin;

  auto const &parent = snode_[nidx];

  // Whatever the parent holds beyond the sum over this feature's bins is
  // attributed to missing values.
  auto const feature_hist = hist.subspan(ptrs[fidx], n_bins);
  auto const observed = GradStats{std::accumulate(
      feature_hist.data(), feature_hist.data() + feature_hist.size(), GradientPairPrecise{})};
  GradStats missing;
  missing.SetSubstract(parent.stats, observed);

  // Best candidate found by this scan.
  SplitEntry candidate;

  // Score one (left, right) partition and record it when both children are valid.
  auto const consider = [&](bst_float split_value, bool missing_left, GradStats const &left,
                            GradStats const &right) {
    if (!IsValid(left, right)) {
      return;
    }
    auto const loss_chg = static_cast<float>(
        evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left}, GradStats{right}) -
        parent.root_gain);
    candidate.Update(loss_chg, fidx, split_value, missing_left, true, left, right);
  };

  for (bst_bin_t bin = first_bin; bin != last_bin; ++bin) {
    bst_float const split_value = values[bin];

    // Missing on the left: missing values are grouped with the other categories.
    GradStats chosen{hist[bin]};
    GradStats rest;
    rest.SetSubstract(parent.stats, chosen);
    consider(split_value, true, rest, chosen);

    // Missing on the right: missing values are grouped with the chosen category.
    rest.SetSubstract(rest, missing);
    chosen.Add(missing);
    consider(split_value, false, rest, chosen);
  }

  p_best->Update(candidate);
}
// Enumerate/Scan the split values of a specific feature
// Returns the sum of gradients corresponding to the data points that contain
// a non-missing value for the particular feature fid.
@@ -102,9 +160,7 @@ class HistEvaluator {
break;
}
case kOneHot: {
// not-chosen categories go to left
right_sum = GradStats{hist[i]};
left_sum.SetSubstract(parent.stats, right_sum);
std::terminate(); // unreachable
break;
}
case kPart: {
@@ -151,7 +207,7 @@ class HistEvaluator {
break;
}
case kOneHot: {
split_pt = cut_val[i];
std::terminate(); // unreachable
break;
}
case kPart: {
@@ -188,7 +244,6 @@ class HistEvaluator {
// Normal, accumulated to left
return left_sum;
case kOneHot:
// Doesn't matter, not accumulating.
return {};
case kPart:
// Accumulated to right due to chosen cats go to right.
@@ -242,8 +297,7 @@ class HistEvaluator {
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_.max_cat_to_onehot)) {
EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
} else {
std::vector<size_t> sorted_idx(n_bins);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);