enable rocm, fix stats.cuh
This commit is contained in:
parent
60795f22de
commit
ca8f4e7993
@ -216,8 +216,14 @@ void SegmentedWeightedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_b
|
|||||||
detail::SegOp<SegIt>{seg_beg, seg_end});
|
detail::SegOp<SegIt>{seg_beg, seg_end});
|
||||||
auto scan_val = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
auto scan_val = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
||||||
detail::WeightOp<WIter>{w_begin, d_sorted_idx});
|
detail::WeightOp<WIter>{w_begin, d_sorted_idx});
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
thrust::inclusive_scan_by_key(thrust::hip::par(caching), scan_key, scan_key + n_weights,
|
||||||
|
scan_val, weights_cdf.begin());
|
||||||
|
#else
|
||||||
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
|
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
|
||||||
scan_val, weights_cdf.begin());
|
scan_val, weights_cdf.begin());
|
||||||
|
#endif
|
||||||
|
|
||||||
auto n_segments = std::distance(seg_beg, seg_end) - 1;
|
auto n_segments = std::distance(seg_beg, seg_end) - 1;
|
||||||
quantiles->SetDevice(ctx->gpu_id);
|
quantiles->SetDevice(ctx->gpu_id);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user