From 60795f22deae2c80aeb4925f0677264104083ef7 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 06:42:20 +0100 Subject: [PATCH] enable rocm, fix linalg_op.cuh --- src/common/linalg_op.cuh | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh index 037ad1ff3..941de49c5 100644 --- a/src/common/linalg_op.cuh +++ b/src/common/linalg_op.cuh @@ -12,8 +12,18 @@ namespace xgboost { namespace linalg { template -void ElementWiseKernelDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) { +#if defined(XGBOOST_USE_HIP) +void ElementWiseKernelDevice(linalg::TensorView t, Fn&& fn, hipStream_t s = nullptr) +#else +void ElementWiseKernelDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) +#endif +{ +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(t.DeviceIdx())); +#else dh::safe_cuda(cudaSetDevice(t.DeviceIdx())); +#endif + static_assert(std::is_void>::value, "For function with return, use transform instead."); if (t.Contiguous()) { @@ -28,7 +38,12 @@ void ElementWiseKernelDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s } template -void ElementWiseTransformDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) { +#if defined(XGBOOST_USE_HIP) +void ElementWiseTransformDevice(linalg::TensorView t, Fn&& fn, hipStream_t s = nullptr) +#else +void ElementWiseTransformDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) +#endif +{ if (t.Contiguous()) { auto ptr = t.Values().data(); dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });