enable rocm, fix threading_utils.cuh

This commit is contained in:
amdsc21 2023-03-08 06:30:52 +01:00
parent 327f1494f1
commit 2eb0b6aae4

View File

@ -62,9 +62,17 @@ SegmentedTrapezoidThreads(xgboost::common::Span<U> group_ptr,
dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(),
out_group_threads_ptr.size());
size_t total = 0;
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMemcpy(
&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
sizeof(total), hipMemcpyDeviceToHost));
#else
dh::safe_cuda(cudaMemcpy(
&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
sizeof(total), cudaMemcpyDeviceToHost));
#endif
return total;
}