diff --git a/src/common/cuda_context.cuh b/src/common/cuda_context.cuh index 9056c1b5e..372b49dde 100644 --- a/src/common/cuda_context.cuh +++ b/src/common/cuda_context.cuh @@ -17,11 +17,21 @@ struct CUDAContext { /** * \brief Caching thrust policy. */ +#if defined(XGBOOST_USE_HIP) + auto CTP() const { return thrust::hip::par(caching_alloc_).on(dh::DefaultStream()); } +#else auto CTP() const { return thrust::cuda::par(caching_alloc_).on(dh::DefaultStream()); } +#endif + /** * \brief Thrust policy without caching allocator. */ +#if defined(XGBOOST_USE_HIP) + auto TP() const { return thrust::hip::par(alloc_).on(dh::DefaultStream()); } +#else auto TP() const { return thrust::cuda::par(alloc_).on(dh::DefaultStream()); } +#endif + auto Stream() const { return dh::DefaultStream(); } }; } // namespace xgboost