Fuse split evaluation kernels (#8026)

This commit is contained in:
Rory Mitchell
2022-07-05 10:24:31 +02:00
committed by GitHub
parent ff1c559084
commit 794cbaa60a
6 changed files with 308 additions and 314 deletions

View File

@@ -21,6 +21,7 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
int32_t device) {
param_ = param;
tree_evaluator_ = TreeEvaluator{param, n_features, device};
has_categoricals_ = cuts.HasCategorical();
if (cuts.HasCategorical()) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
@@ -69,42 +70,46 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
template <typename GradientSumT>
common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
EvaluateSplitInputs<GradientSumT> const &left, EvaluateSplitInputs<GradientSumT> const &right,
common::Span<const EvaluateSplitInputs> d_inputs, EvaluateSplitSharedInputs shared_inputs,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto sorted_idx = this->SortedIdx(left);
auto sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
dh::Iota(sorted_idx);
auto data = this->SortInput(left);
auto data = this->SortInput(d_inputs.size(), shared_inputs.feature_values.size());
auto it = thrust::make_counting_iterator(0u);
auto d_feature_idx = dh::ToSpan(feature_idx_);
auto total_bins = shared_inputs.feature_values.size();
thrust::transform(thrust::cuda::par(alloc), it, it + data.size(), dh::tbegin(data),
[=] XGBOOST_DEVICE(uint32_t i) {
auto is_left = i < left.feature_values.size();
auto const &input = is_left ? left : right;
auto j = i - (is_left ? 0 : input.feature_values.size());
auto const &input = d_inputs[i / total_bins];
auto j = i % total_bins;
auto fidx = d_feature_idx[j];
if (common::IsCat(input.feature_types, fidx)) {
auto lw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[j]);
if (common::IsCat(shared_inputs.feature_types, fidx)) {
auto lw = evaluator.CalcWeightCat(shared_inputs.param,
input.gradient_histogram[j]);
return thrust::make_tuple(i, lw);
}
return thrust::make_tuple(i, 0.0);
});
// Sort an array segmented according to
// - nodes
// - features within each node
// - gradients within each feature
thrust::stable_sort_by_key(thrust::cuda::par(alloc), dh::tbegin(data), dh::tend(data),
dh::tbegin(sorted_idx),
[=] XGBOOST_DEVICE(SortPair const &l, SortPair const &r) {
auto li = thrust::get<0>(l);
auto ri = thrust::get<0>(r);
auto l_is_left = li < left.feature_values.size();
auto r_is_left = ri < left.feature_values.size();
auto l_node = li / total_bins;
auto r_node = ri / total_bins;
if (l_is_left != r_is_left) {
return l_is_left; // not the same node
if (l_node != r_node) {
return l_node < r_node; // not the same node
}
auto const &input = l_is_left ? left : right;
li -= (l_is_left ? 0 : input.feature_values.size());
ri -= (r_is_left ? 0 : input.feature_values.size());
li = li % total_bins;
ri = ri % total_bins;
auto lfidx = d_feature_idx[li];
auto rfidx = d_feature_idx[ri];
@@ -113,7 +118,7 @@ common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
return lfidx < rfidx; // not the same feature
}
if (common::IsCat(input.feature_types, lfidx)) {
if (common::IsCat(shared_inputs.feature_types, lfidx)) {
auto lw = thrust::get<1>(l);
auto rw = thrust::get<1>(r);
return lw < rw;