Extract device algorithms. (#8789)

This commit is contained in:
Jiaming Yuan
2023-02-13 20:53:53 +08:00
committed by GitHub
parent 457f704e3d
commit 31d3ec07af
13 changed files with 361 additions and 218 deletions

View File

@@ -17,6 +17,7 @@
#include <limits> // std::numeric_limits
#include <type_traits> // std::is_floating_point,std::iterator_traits
#include "algorithm.cuh" // SegmentedArgMergeSort
#include "cuda_context.cuh" // CUDAContext
#include "device_helpers.cuh"
#include "xgboost/context.h" // Context
@@ -150,7 +151,7 @@ void SegmentedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_begin, Se
ValIt val_begin, ValIt val_end, HostDeviceVector<float>* quantiles) {
dh::device_vector<std::size_t> sorted_idx;
using Tup = thrust::tuple<std::size_t, float>;
dh::SegmentedArgSort(seg_begin, seg_end, val_begin, val_end, &sorted_idx);
common::SegmentedArgMergeSort(ctx, seg_begin, seg_end, val_begin, val_end, &sorted_idx);
auto n_segments = std::distance(seg_begin, seg_end) - 1;
if (n_segments <= 0) {
return;
@@ -203,7 +204,7 @@ void SegmentedWeightedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_b
HostDeviceVector<float>* quantiles) {
auto cuctx = ctx->CUDACtx();
dh::device_vector<std::size_t> sorted_idx;
dh::SegmentedArgSort(seg_beg, seg_end, val_begin, val_end, &sorted_idx);
common::SegmentedArgMergeSort(ctx, seg_beg, seg_end, val_begin, val_end, &sorted_idx);
auto d_sorted_idx = dh::ToSpan(sorted_idx);
std::size_t n_weights = std::distance(w_begin, w_end);
dh::device_vector<float> weights_cdf(n_weights);