finish data.cu

This commit is contained in:
amdsc21
2023-03-10 05:00:57 +01:00
parent 713ab9e1a0
commit ccce4cf7e1
8 changed files with 59 additions and 16 deletions

View File

@@ -4,7 +4,12 @@
#ifndef XGBOOST_COMMON_LINALG_OP_CUH_
#define XGBOOST_COMMON_LINALG_OP_CUH_
#if defined(XGBOOST_USE_CUDA)
#include "device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "device_helpers.hip.h"
#endif
#include "linalg_op.h"
#include "xgboost/context.h"
#include "xgboost/linalg.h"
@@ -14,13 +19,13 @@ namespace linalg {
template <typename T, int32_t D, typename Fn>
#if defined(XGBOOST_USE_HIP)
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
#else
#elif defined(XGBOOST_USE_CUDA)
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
#endif
{
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(t.DeviceIdx()));
#else
#elif defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
#endif
@@ -40,7 +45,7 @@ void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s
template <typename T, int32_t D, typename Fn>
#if defined(XGBOOST_USE_HIP)
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
#else
#elif defined(XGBOOST_USE_CUDA)
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
#endif
{

View File

@@ -42,7 +42,7 @@ void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& f
}
}
#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
template <typename T, int32_t D, typename Fn>
void ElementWiseKernelDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr) {
common::AssertGPUSupport();
@@ -60,7 +60,7 @@ void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn)
}
ElementWiseKernelHost(t, ctx->Threads(), fn);
}
#endif // !defined(XGBOOST_USE_CUDA)
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
template <typename T, std::int32_t kDim>
auto cbegin(TensorView<T, kDim> const& v) { // NOLINT