enable rocm, fix linalg_op.cuh
This commit is contained in:
parent
05fdca893f
commit
60795f22de
@ -12,8 +12,18 @@
|
||||
namespace xgboost {
|
||||
namespace linalg {
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
|
||||
#else
|
||||
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
|
||||
#endif
|
||||
{
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(t.DeviceIdx()));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
|
||||
#endif
|
||||
|
||||
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
|
||||
"For function with return, use transform instead.");
|
||||
if (t.Contiguous()) {
|
||||
@ -28,7 +38,12 @@ void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s
|
||||
}
|
||||
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
|
||||
#else
|
||||
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
|
||||
#endif
|
||||
{
|
||||
if (t.Contiguous()) {
|
||||
auto ptr = t.Values().data();
|
||||
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user