enable rocm, fix linalg_op.cuh
This commit is contained in:
parent
05fdca893f
commit
60795f22de
@ -12,8 +12,18 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace linalg {
|
namespace linalg {
|
||||||
template <typename T, int32_t D, typename Fn>
|
template <typename T, int32_t D, typename Fn>
|
||||||
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
|
||||||
|
#else
|
||||||
|
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
|
||||||
|
#endif
|
||||||
|
{
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(t.DeviceIdx()));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
|
dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
|
||||||
|
#endif
|
||||||
|
|
||||||
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
|
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
|
||||||
"For function with return, use transform instead.");
|
"For function with return, use transform instead.");
|
||||||
if (t.Contiguous()) {
|
if (t.Contiguous()) {
|
||||||
@ -28,7 +38,12 @@ void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int32_t D, typename Fn>
|
template <typename T, int32_t D, typename Fn>
|
||||||
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
|
||||||
|
#else
|
||||||
|
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
|
||||||
|
#endif
|
||||||
|
{
|
||||||
if (t.Contiguous()) {
|
if (t.Contiguous()) {
|
||||||
auto ptr = t.Values().data();
|
auto ptr = t.Values().data();
|
||||||
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
|
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user