Handle missing values in one hot splits. (#7934)
This commit is contained in:
@@ -45,14 +45,72 @@ class HistEvaluator {
|
||||
// then - there are no missing values
|
||||
// else - there are missing values
|
||||
bool static SplitContainsMissingValues(const GradStats e, const NodeEntry &snode) {
|
||||
if (e.GetGrad() == snode.stats.GetGrad() &&
|
||||
e.GetHess() == snode.stats.GetHess()) {
|
||||
if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsValid(GradStats const &left, GradStats const &right) const {
|
||||
return left.GetHess() >= param_.min_child_weight && right.GetHess() >= param_.min_child_weight;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Use learned direction with one-hot split. Other implementations (LGB, sklearn)
|
||||
* create a pseudo-category for missing value but here we just do a complete scan
|
||||
* to avoid making specialized histogram bin.
|
||||
*/
|
||||
void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
|
||||
bst_feature_t fidx, bst_node_t nidx,
|
||||
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
|
||||
SplitEntry *p_best) const {
|
||||
const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
|
||||
const std::vector<bst_float> &cut_val = cut.Values();
|
||||
|
||||
bst_bin_t ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
|
||||
bst_bin_t iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
|
||||
bst_bin_t n_bins = iend - ibegin;
|
||||
|
||||
GradStats left_sum;
|
||||
GradStats right_sum;
|
||||
// best split so far
|
||||
SplitEntry best;
|
||||
|
||||
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
|
||||
auto feature_sum = GradStats{
|
||||
std::accumulate(f_hist.data(), f_hist.data() + f_hist.size(), GradientPairPrecise{})};
|
||||
GradStats missing;
|
||||
auto const &parent = snode_[nidx];
|
||||
missing.SetSubstract(parent.stats, feature_sum);
|
||||
|
||||
for (bst_bin_t i = ibegin; i != iend; i += 1) {
|
||||
auto split_pt = cut_val[i];
|
||||
|
||||
// missing on left (treat missing as other categories)
|
||||
right_sum = GradStats{hist[i]};
|
||||
left_sum.SetSubstract(parent.stats, right_sum);
|
||||
if (IsValid(left_sum, right_sum)) {
|
||||
auto missing_left_chg = static_cast<float>(
|
||||
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
|
||||
parent.root_gain);
|
||||
best.Update(missing_left_chg, fidx, split_pt, true, true, left_sum, right_sum);
|
||||
}
|
||||
|
||||
// missing on right (treat missing as chosen category)
|
||||
left_sum.SetSubstract(left_sum, missing);
|
||||
right_sum.Add(missing);
|
||||
if (IsValid(left_sum, right_sum)) {
|
||||
auto missing_right_chg = static_cast<float>(
|
||||
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
|
||||
parent.root_gain);
|
||||
best.Update(missing_right_chg, fidx, split_pt, false, true, left_sum, right_sum);
|
||||
}
|
||||
}
|
||||
|
||||
p_best->Update(best);
|
||||
}
|
||||
|
||||
// Enumerate/Scan the split values of specific feature
|
||||
// Returns the sum of gradients corresponding to the data points that contains
|
||||
// a non-missing value for the particular feature fid.
|
||||
@@ -102,9 +160,7 @@ class HistEvaluator {
|
||||
break;
|
||||
}
|
||||
case kOneHot: {
|
||||
// not-chosen categories go to left
|
||||
right_sum = GradStats{hist[i]};
|
||||
left_sum.SetSubstract(parent.stats, right_sum);
|
||||
std::terminate(); // unreachable
|
||||
break;
|
||||
}
|
||||
case kPart: {
|
||||
@@ -151,7 +207,7 @@ class HistEvaluator {
|
||||
break;
|
||||
}
|
||||
case kOneHot: {
|
||||
split_pt = cut_val[i];
|
||||
std::terminate(); // unreachable
|
||||
break;
|
||||
}
|
||||
case kPart: {
|
||||
@@ -188,7 +244,6 @@ class HistEvaluator {
|
||||
// Normal, accumulated to left
|
||||
return left_sum;
|
||||
case kOneHot:
|
||||
// Doesn't matter, not accumulating.
|
||||
return {};
|
||||
case kPart:
|
||||
// Accumulated to right due to chosen cats go to right.
|
||||
@@ -242,8 +297,7 @@ class HistEvaluator {
|
||||
if (is_cat) {
|
||||
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
|
||||
if (common::UseOneHot(n_bins, param_.max_cat_to_onehot)) {
|
||||
EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
|
||||
EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
|
||||
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
|
||||
} else {
|
||||
std::vector<size_t> sorted_idx(n_bins);
|
||||
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
|
||||
|
||||
Reference in New Issue
Block a user