From 2eb0b6aae46de580a0c111856cb94ec074720d51 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 06:30:52 +0100 Subject: [PATCH] enable rocm, fix threading_utils.cuh --- src/common/threading_utils.cuh | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/common/threading_utils.cuh b/src/common/threading_utils.cuh index c21d312d2..5ff78144d 100644 --- a/src/common/threading_utils.cuh +++ b/src/common/threading_utils.cuh @@ -62,9 +62,17 @@ SegmentedTrapezoidThreads(xgboost::common::Span group_ptr, dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(), out_group_threads_ptr.size()); size_t total = 0; + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpy( + &total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1, + sizeof(total), hipMemcpyDeviceToHost)); +#else dh::safe_cuda(cudaMemcpy( &total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1, sizeof(total), cudaMemcpyDeviceToHost)); +#endif + return total; }