finish data.cu
This commit is contained in:
@@ -4,7 +4,12 @@
|
||||
#ifndef XGBOOST_COMMON_LINALG_OP_CUH_
|
||||
#define XGBOOST_COMMON_LINALG_OP_CUH_
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include "device_helpers.cuh"
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
#include "device_helpers.hip.h"
|
||||
#endif
|
||||
|
||||
#include "linalg_op.h"
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/linalg.h"
|
||||
@@ -14,13 +19,13 @@ namespace linalg {
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
|
||||
#else
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
|
||||
#endif
|
||||
{
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(t.DeviceIdx()));
|
||||
#else
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
|
||||
#endif
|
||||
|
||||
@@ -40,7 +45,7 @@ void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
|
||||
#else
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
|
||||
#endif
|
||||
{
|
||||
|
||||
@@ -42,7 +42,7 @@ void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& f
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseKernelDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr) {
|
||||
common::AssertGPUSupport();
|
||||
@@ -60,7 +60,7 @@ void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn)
|
||||
}
|
||||
ElementWiseKernelHost(t, ctx->Threads(), fn);
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
||||
|
||||
template <typename T, std::int32_t kDim>
|
||||
auto cbegin(TensorView<T, kDim> const& v) { // NOLINT
|
||||
|
||||
Reference in New Issue
Block a user