enable rocm, fix threading_utils.cuh
This commit is contained in:
parent
327f1494f1
commit
2eb0b6aae4
@ -62,9 +62,17 @@ SegmentedTrapezoidThreads(xgboost::common::Span<U> group_ptr,
|
|||||||
dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(),
|
dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(),
|
||||||
out_group_threads_ptr.size());
|
out_group_threads_ptr.size());
|
||||||
size_t total = 0;
|
size_t total = 0;
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipMemcpy(
|
||||||
|
&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
|
||||||
|
sizeof(total), hipMemcpyDeviceToHost));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaMemcpy(
|
dh::safe_cuda(cudaMemcpy(
|
||||||
&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
|
&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
|
||||||
sizeof(total), cudaMemcpyDeviceToHost));
|
sizeof(total), cudaMemcpyDeviceToHost));
|
||||||
|
#endif
|
||||||
|
|
||||||
return total;
|
return total;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user