From 327f1494f1a5131a518104f7b6bdff19108197c5 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 06:29:45 +0100 Subject: [PATCH] enable rocm, fix cuda_context.cuh --- src/common/cuda_context.cuh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/common/cuda_context.cuh b/src/common/cuda_context.cuh index 9056c1b5e..372b49dde 100644 --- a/src/common/cuda_context.cuh +++ b/src/common/cuda_context.cuh @@ -17,11 +17,21 @@ struct CUDAContext { /** * \brief Caching thrust policy. */ +#if defined(XGBOOST_USE_HIP) + auto CTP() const { return thrust::hip::par(caching_alloc_).on(dh::DefaultStream()); } +#else auto CTP() const { return thrust::cuda::par(caching_alloc_).on(dh::DefaultStream()); } +#endif + /** * \brief Thrust policy without caching allocator. */ +#if defined(XGBOOST_USE_HIP) + auto TP() const { return thrust::hip::par(alloc_).on(dh::DefaultStream()); } +#else auto TP() const { return thrust::cuda::par(alloc_).on(dh::DefaultStream()); } +#endif + auto Stream() const { return dh::DefaultStream(); } }; } // namespace xgboost