Support optimal partitioning for GPU hist. (#7652)
* Implement `MaxCategory` in quantile. * Implement partition-based split for GPU evaluation. Currently, it's based on the existing evaluation function. * Extract an evaluator from GPU Hist to store the needed states. * Added some CUDA stream/event utilities. * Update document with references. * Fixed a bug in approx evaluator where the number of data points is less than the number of categories.
This commit is contained in:
@@ -53,7 +53,6 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
enum SplitType { kNum = 0, kOneHot = 1, kPart = 2 };
|
||||
|
||||
// Enumerate/Scan the split values of specific feature
|
||||
// Returns the sum of gradients corresponding to the data points that contains
|
||||
@@ -137,7 +136,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
|
||||
static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum},
|
||||
GradStats{right_sum}) -
|
||||
parent.root_gain);
|
||||
split_pt = cut_val[i];
|
||||
split_pt = cut_val[i]; // not used for partition based
|
||||
improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
|
||||
left_sum, right_sum);
|
||||
} else {
|
||||
@@ -180,10 +179,10 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
|
||||
|
||||
if (d_step == 1) {
|
||||
std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1),
|
||||
[&cat_bits](size_t c) { cat_bits.Set(c); });
|
||||
[&](size_t c) { cat_bits.Set(cut_val[c + ibegin]); });
|
||||
} else {
|
||||
std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh),
|
||||
[&cat_bits](size_t c) { cat_bits.Set(c); });
|
||||
[&](size_t c) { cat_bits.Set(cut_val[c + cut_ptr[fidx]]); });
|
||||
}
|
||||
}
|
||||
p_best->Update(best);
|
||||
@@ -231,6 +230,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
|
||||
}
|
||||
}
|
||||
auto evaluator = tree_evaluator_.GetEvaluator();
|
||||
auto const& cut_ptrs = cut.Ptrs();
|
||||
|
||||
common::ParallelFor2d(space, n_threads_, [&](size_t nidx_in_set, common::Range1d r) {
|
||||
auto tidx = omp_get_thread_num();
|
||||
@@ -246,26 +246,22 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
|
||||
continue;
|
||||
}
|
||||
if (is_cat) {
|
||||
auto n_bins = cut.Ptrs().at(fidx + 1) - cut.Ptrs()[fidx];
|
||||
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
|
||||
if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) {
|
||||
EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
|
||||
EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
|
||||
} else {
|
||||
auto const &cut_ptr = cut.Ptrs();
|
||||
std::vector<size_t> sorted_idx(n_bins);
|
||||
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
|
||||
auto feat_hist = histogram.subspan(cut_ptr[fidx], n_bins);
|
||||
auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
|
||||
// Sort the histogram to get contiguous partitions.
|
||||
std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
|
||||
auto ret = evaluator.CalcWeightCat(param_, feat_hist[l]) <
|
||||
evaluator.CalcWeightCat(param_, feat_hist[r]);
|
||||
static_assert(std::is_same<decltype(ret), bool>::value, "");
|
||||
return ret;
|
||||
});
|
||||
auto grad_stats =
|
||||
EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
|
||||
if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
|
||||
EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
|
||||
}
|
||||
EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
|
||||
EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
|
||||
}
|
||||
} else {
|
||||
auto grad_stats =
|
||||
@@ -313,6 +309,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
|
||||
cat_bits.Set(cat);
|
||||
} else {
|
||||
split_cats = candidate.split.cat_bits;
|
||||
common::CatBitField cat_bits{split_cats};
|
||||
}
|
||||
|
||||
tree.ExpandCategorical(
|
||||
|
||||
Reference in New Issue
Block a user