diff --git a/src/common/threading_utils.cuh b/src/common/threading_utils.cuh index c21d312d2..5ff78144d 100644 --- a/src/common/threading_utils.cuh +++ b/src/common/threading_utils.cuh @@ -62,9 +62,17 @@ SegmentedTrapezoidThreads(xgboost::common::Span group_ptr, dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(), out_group_threads_ptr.size()); size_t total = 0; + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpy( + &total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1, + sizeof(total), hipMemcpyDeviceToHost)); +#else dh::safe_cuda(cudaMemcpy( &total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1, sizeof(total), cudaMemcpyDeviceToHost)); +#endif + return total; }