diff --git a/src/common/stats.cuh b/src/common/stats.cuh index f31233461..28115abef 100644 --- a/src/common/stats.cuh +++ b/src/common/stats.cuh @@ -216,8 +216,14 @@ void SegmentedWeightedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_b detail::SegOp{seg_beg, seg_end}); auto scan_val = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), detail::WeightOp{w_begin, d_sorted_idx}); + +#if defined(XGBOOST_USE_HIP) + thrust::inclusive_scan_by_key(thrust::hip::par(caching), scan_key, scan_key + n_weights, + scan_val, weights_cdf.begin()); +#else thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights, scan_val, weights_cdf.begin()); +#endif auto n_segments = std::distance(seg_beg, seg_end) - 1; quantiles->SetDevice(ctx->gpu_id);