Fuse split evaluation kernels (#8026)
This commit is contained in:
@@ -21,6 +21,7 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
|
||||
int32_t device) {
|
||||
param_ = param;
|
||||
tree_evaluator_ = TreeEvaluator{param, n_features, device};
|
||||
has_categoricals_ = cuts.HasCategorical();
|
||||
if (cuts.HasCategorical()) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
|
||||
@@ -69,42 +70,46 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
|
||||
|
||||
template <typename GradientSumT>
|
||||
common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
|
||||
EvaluateSplitInputs<GradientSumT> const &left, EvaluateSplitInputs<GradientSumT> const &right,
|
||||
common::Span<const EvaluateSplitInputs> d_inputs, EvaluateSplitSharedInputs shared_inputs,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto sorted_idx = this->SortedIdx(left);
|
||||
auto sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
|
||||
dh::Iota(sorted_idx);
|
||||
auto data = this->SortInput(left);
|
||||
auto data = this->SortInput(d_inputs.size(), shared_inputs.feature_values.size());
|
||||
auto it = thrust::make_counting_iterator(0u);
|
||||
auto d_feature_idx = dh::ToSpan(feature_idx_);
|
||||
auto total_bins = shared_inputs.feature_values.size();
|
||||
thrust::transform(thrust::cuda::par(alloc), it, it + data.size(), dh::tbegin(data),
|
||||
[=] XGBOOST_DEVICE(uint32_t i) {
|
||||
auto is_left = i < left.feature_values.size();
|
||||
auto const &input = is_left ? left : right;
|
||||
auto j = i - (is_left ? 0 : input.feature_values.size());
|
||||
auto const &input = d_inputs[i / total_bins];
|
||||
auto j = i % total_bins;
|
||||
auto fidx = d_feature_idx[j];
|
||||
if (common::IsCat(input.feature_types, fidx)) {
|
||||
auto lw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[j]);
|
||||
if (common::IsCat(shared_inputs.feature_types, fidx)) {
|
||||
auto lw = evaluator.CalcWeightCat(shared_inputs.param,
|
||||
input.gradient_histogram[j]);
|
||||
return thrust::make_tuple(i, lw);
|
||||
}
|
||||
return thrust::make_tuple(i, 0.0);
|
||||
});
|
||||
// Sort an array segmented according to
|
||||
// - nodes
|
||||
// - features within each node
|
||||
// - gradients within each feature
|
||||
thrust::stable_sort_by_key(thrust::cuda::par(alloc), dh::tbegin(data), dh::tend(data),
|
||||
dh::tbegin(sorted_idx),
|
||||
[=] XGBOOST_DEVICE(SortPair const &l, SortPair const &r) {
|
||||
auto li = thrust::get<0>(l);
|
||||
auto ri = thrust::get<0>(r);
|
||||
|
||||
auto l_is_left = li < left.feature_values.size();
|
||||
auto r_is_left = ri < left.feature_values.size();
|
||||
auto l_node = li / total_bins;
|
||||
auto r_node = ri / total_bins;
|
||||
|
||||
if (l_is_left != r_is_left) {
|
||||
return l_is_left; // not the same node
|
||||
if (l_node != r_node) {
|
||||
return l_node < r_node; // not the same node
|
||||
}
|
||||
|
||||
auto const &input = l_is_left ? left : right;
|
||||
li -= (l_is_left ? 0 : input.feature_values.size());
|
||||
ri -= (r_is_left ? 0 : input.feature_values.size());
|
||||
li = li % total_bins;
|
||||
ri = ri % total_bins;
|
||||
|
||||
auto lfidx = d_feature_idx[li];
|
||||
auto rfidx = d_feature_idx[ri];
|
||||
@@ -113,7 +118,7 @@ common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
|
||||
return lfidx < rfidx; // not the same feature
|
||||
}
|
||||
|
||||
if (common::IsCat(input.feature_types, lfidx)) {
|
||||
if (common::IsCat(shared_inputs.feature_types, lfidx)) {
|
||||
auto lw = thrust::get<1>(l);
|
||||
auto rw = thrust::get<1>(r);
|
||||
return lw < rw;
|
||||
|
||||
Reference in New Issue
Block a user