enable rocm, fix cuda_context.cuh

This commit is contained in:
amdsc21 2023-03-08 06:29:45 +01:00
parent fa92aa56ee
commit 327f1494f1

View File

@ -17,11 +17,21 @@ struct CUDAContext {
/**
* \brief Caching thrust policy.
*/
#if defined(XGBOOST_USE_HIP)
auto CTP() const { return thrust::hip::par(caching_alloc_).on(dh::DefaultStream()); }
#else
auto CTP() const { return thrust::cuda::par(caching_alloc_).on(dh::DefaultStream()); }
#endif
/**
* \brief Thrust policy without caching allocator.
*/
#if defined(XGBOOST_USE_HIP)
auto TP() const { return thrust::hip::par(alloc_).on(dh::DefaultStream()); }
#else
auto TP() const { return thrust::cuda::par(alloc_).on(dh::DefaultStream()); }
#endif
auto Stream() const { return dh::DefaultStream(); }
};
} // namespace xgboost