enable rocm, fix stats.cuh
This commit is contained in:
parent
60795f22de
commit
ca8f4e7993
@ -216,8 +216,14 @@ void SegmentedWeightedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_b
|
||||
detail::SegOp<SegIt>{seg_beg, seg_end});
|
||||
auto scan_val = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
||||
detail::WeightOp<WIter>{w_begin, d_sorted_idx});
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
thrust::inclusive_scan_by_key(thrust::hip::par(caching), scan_key, scan_key + n_weights,
|
||||
scan_val, weights_cdf.begin());
|
||||
#else
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
|
||||
scan_val, weights_cdf.begin());
|
||||
#endif
|
||||
|
||||
auto n_segments = std::distance(seg_beg, seg_end) - 1;
|
||||
quantiles->SetDevice(ctx->gpu_id);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user