Enable ROCm; fix threading_utils.cuh
This commit is contained in:
parent
327f1494f1
commit
2eb0b6aae4
@ -62,9 +62,17 @@ SegmentedTrapezoidThreads(xgboost::common::Span<U> group_ptr,
|
||||
dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(),
|
||||
out_group_threads_ptr.size());
|
||||
size_t total = 0;
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipMemcpy(
|
||||
&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
|
||||
sizeof(total), hipMemcpyDeviceToHost));
|
||||
#else
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
|
||||
sizeof(total), cudaMemcpyDeviceToHost));
|
||||
#endif
|
||||
|
||||
return total;
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user