Enable ROCm, fix linalg_op.cuh

amdsc21 2023-03-08 06:42:20 +01:00
parent 05fdca893f
commit 60795f22de


@@ -12,8 +12,18 @@
namespace xgboost {
namespace linalg {
template <typename T, int32_t D, typename Fn>
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
#if defined(XGBOOST_USE_HIP)
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
#else
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
#endif
{
#if defined(XGBOOST_USE_HIP)
  dh::safe_cuda(hipSetDevice(t.DeviceIdx()));
#else
  dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
#endif
  static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
                "For function with return, use transform instead.");
  if (t.Contiguous()) {
@@ -28,7 +38,12 @@ void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s
}
template <typename T, int32_t D, typename Fn>
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
#if defined(XGBOOST_USE_HIP)
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
#else
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
#endif
{
  if (t.Contiguous()) {
    auto ptr = t.Values().data();
    dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
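
As a worked illustration of the guard pattern this commit applies (the function signature is duplicated under XGBOOST_USE_HIP so the stream parameter gets the backend's native type, and the device-selection call is switched the same way), here is a minimal self-contained sketch. It is not the XGBoost source: FillKernel and SetDeviceAndFill are hypothetical names, and error handling (dh::safe_cuda) and the dh::LaunchN helper used in the real code are omitted.

// Minimal sketch of the CUDA/HIP guard pattern, assuming XGBOOST_USE_HIP is
// defined only for ROCm builds. Names here are for illustration, not XGBoost's.
#include <cstddef>

#if defined(XGBOOST_USE_HIP)
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

__global__ void FillKernel(float* data, std::size_t n, float v) {
  std::size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { data[i] = v; }
}

// The signature is duplicated under the guard so the stream argument has the
// backend's native type; the function body after the brace is shared.
#if defined(XGBOOST_USE_HIP)
void SetDeviceAndFill(int device, float* data, std::size_t n, float v, hipStream_t s = nullptr)
#else
void SetDeviceAndFill(int device, float* data, std::size_t n, float v, cudaStream_t s = nullptr)
#endif
{
#if defined(XGBOOST_USE_HIP)
  hipSetDevice(device);  // ROCm path
  hipLaunchKernelGGL(FillKernel, dim3((n + 255) / 256), dim3(256), 0, s, data, n, v);
#else
  cudaSetDevice(device);  // CUDA path
  FillKernel<<<(n + 255) / 256, 256, 0, s>>>(data, n, v);
#endif
}

Keeping a single declaration visible to callers and confining the backend differences to the preprocessor branches means call sites compile unchanged under both nvcc and hipcc, which is exactly what the two hunks above do for ElementWiseKernelDevice and ElementWiseTransformDevice.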