enable rocm, fix stats.cuh

This commit is contained in:
amdsc21 2023-03-08 06:43:06 +01:00
parent 60795f22de
commit ca8f4e7993

View File

@ -216,8 +216,14 @@ void SegmentedWeightedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_b
detail::SegOp<SegIt>{seg_beg, seg_end});
auto scan_val = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
detail::WeightOp<WIter>{w_begin, d_sorted_idx});
#if defined(XGBOOST_USE_HIP)
thrust::inclusive_scan_by_key(thrust::hip::par(caching), scan_key, scan_key + n_weights,
scan_val, weights_cdf.begin());
#else
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
scan_val, weights_cdf.begin());
#endif
auto n_segments = std::distance(seg_beg, seg_end) - 1;
quantiles->SetDevice(ctx->gpu_id);