From ccce4cf7e1dd5cf6441c4adc8c3473cdb6b0bf93 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Fri, 10 Mar 2023 05:00:57 +0100 Subject: [PATCH] finish data.cu --- src/common/linalg_op.cuh | 11 ++++++++--- src/common/linalg_op.h | 4 ++-- src/data/data.cc | 4 ++-- src/data/data.cu | 34 +++++++++++++++++++++++++++++++++ src/data/data.hip | 4 ++++ src/objective/quantile_obj.cc | 4 ++-- src/objective/quantile_obj.cu | 10 +++++----- src/objective/regression_obj.cc | 4 ++-- 8 files changed, 59 insertions(+), 16 deletions(-) diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh index 941de49c5..fdd72df75 100644 --- a/src/common/linalg_op.cuh +++ b/src/common/linalg_op.cuh @@ -4,7 +4,12 @@ #ifndef XGBOOST_COMMON_LINALG_OP_CUH_ #define XGBOOST_COMMON_LINALG_OP_CUH_ +#if defined(XGBOOST_USE_CUDA) #include "device_helpers.cuh" +#elif defined(XGBOOST_USE_HIP) +#include "device_helpers.hip.h" +#endif + #include "linalg_op.h" #include "xgboost/context.h" #include "xgboost/linalg.h" @@ -14,13 +19,13 @@ namespace linalg { template #if defined(XGBOOST_USE_HIP) void ElementWiseKernelDevice(linalg::TensorView t, Fn&& fn, hipStream_t s = nullptr) -#else +#elif defined(XGBOOST_USE_CUDA) void ElementWiseKernelDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) #endif { #if defined(XGBOOST_USE_HIP) dh::safe_cuda(hipSetDevice(t.DeviceIdx())); -#else +#elif defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaSetDevice(t.DeviceIdx())); #endif @@ -40,7 +45,7 @@ void ElementWiseKernelDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s template #if defined(XGBOOST_USE_HIP) void ElementWiseTransformDevice(linalg::TensorView t, Fn&& fn, hipStream_t s = nullptr) -#else +#elif defined(XGBOOST_USE_CUDA) void ElementWiseTransformDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) #endif { diff --git a/src/common/linalg_op.h b/src/common/linalg_op.h index f55927402..7e908135c 100644 --- a/src/common/linalg_op.h +++ 
b/src/common/linalg_op.h @@ -42,7 +42,7 @@ void ElementWiseKernelHost(linalg::TensorView t, int32_t n_threads, Fn&& f } } -#if !defined(XGBOOST_USE_CUDA) +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) template void ElementWiseKernelDevice(linalg::TensorView, Fn&&, void* = nullptr) { common::AssertGPUSupport(); @@ -60,7 +60,7 @@ void ElementWiseKernel(Context const* ctx, linalg::TensorView t, Fn&& fn) } ElementWiseKernelHost(t, ctx->Threads(), fn); } -#endif // !defined(XGBOOST_USE_CUDA) +#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) template auto cbegin(TensorView const& v) { // NOLINT diff --git a/src/data/data.cc b/src/data/data.cc index d24048a2a..b61534ce4 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -755,9 +755,9 @@ void MetaInfo::Validate(std::int32_t device) const { } } -#if !defined(XGBOOST_USE_CUDA) +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); } -#endif // !defined(XGBOOST_USE_CUDA) +#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) using DMatrixThreadLocal = dmlc::ThreadLocalStore>; diff --git a/src/data/data.cu b/src/data/data.cu index 4dedc7d24..7854ccd3f 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -5,7 +5,13 @@ * \brief Handles setting metainfo from array interface.
*/ #include "../common/cuda_context.cuh" + +#if defined(XGBOOST_USE_CUDA) #include "../common/device_helpers.cuh" +#elif defined(XGBOOST_USE_HIP) +#include "../common/device_helpers.hip.h" +#endif + #include "../common/linalg_op.cuh" #include "array_interface.h" #include "device_adapter.cuh" @@ -15,14 +21,22 @@ #include "xgboost/json.h" #include "xgboost/logging.h" +#if defined(XGBOOST_USE_HIP) +namespace cub = hipcub; +#endif + namespace xgboost { namespace { auto SetDeviceToPtr(void const* ptr) { +#if defined(XGBOOST_USE_CUDA) cudaPointerAttributes attr; dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr)); int32_t ptr_device = attr.device; dh::safe_cuda(cudaSetDevice(ptr_device)); return ptr_device; +#elif defined(XGBOOST_USE_HIP) + hipPointerAttribute_t attr; dh::safe_cuda(hipPointerGetAttributes(&attr, ptr)); dh::safe_cuda(hipSetDevice(attr.device)); return attr.device; +#endif } template @@ -43,8 +57,14 @@ void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tens std::copy(array.shape, array.shape + D, shape.data()); // set data data->Resize(array.n); + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T), cudaMemcpyDefault, ctx->Stream())); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T), + hipMemcpyDefault, ctx->Stream())); +#endif }); return; } @@ -94,8 +114,15 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector* p_ } }); bool non_dec = true; + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaMemcpy(&non_dec, flag.data().get(), sizeof(bool), cudaMemcpyDeviceToHost)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpy(&non_dec, flag.data().get(), sizeof(bool), + hipMemcpyDeviceToHost)); +#endif + CHECK(non_dec) << "`qid` must be sorted in increasing order along with data."; size_t bytes = 0; dh::caching_device_vector out(array_interface.Shape(0)); @@ -113,8 +140,15 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector* p_ group_ptr_.clear();
group_ptr_.resize(h_num_runs_out + 1, 0); dh::XGBCachingDeviceAllocator alloc; + +#if defined(XGBOOST_USE_CUDA) thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(), cnt.begin() + h_num_runs_out, cnt.begin()); +#elif defined(XGBOOST_USE_HIP) + thrust::inclusive_scan(thrust::hip::par(alloc), cnt.begin(), + cnt.begin() + h_num_runs_out, cnt.begin()); +#endif + thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out, group_ptr_.begin() + 1); } diff --git a/src/data/data.hip b/src/data/data.hip index e69de29bb..a0b80a7e0 100644 --- a/src/data/data.hip +++ b/src/data/data.hip @@ -0,0 +1,4 @@ + +#if defined(XGBOOST_USE_HIP) +#include "data.cu" +#endif diff --git a/src/objective/quantile_obj.cc b/src/objective/quantile_obj.cc index 89e2d6010..0316b0cc8 100644 --- a/src/objective/quantile_obj.cc +++ b/src/objective/quantile_obj.cc @@ -13,6 +13,6 @@ DMLC_REGISTRY_FILE_TAG(quantile_obj); } // namespace obj } // namespace xgboost -#ifndef XGBOOST_USE_CUDA +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) #include "quantile_obj.cu" -#endif // !defined(XBGOOST_USE_CUDA) +#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index 0a40758bc..5b404692b 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -19,7 +19,7 @@ #include "xgboost/objective.h" // ObjFunction #include "xgboost/parameter.h" // XGBoostParameter -#if defined(XGBOOST_USE_CUDA) +#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP) #include "../common/linalg_op.cuh" // ElementWiseKernel #include "../common/stats.cuh" // SegmentedQuantile @@ -123,7 +123,7 @@ class QuantileRegression : public ObjFunction { } } } else { -#if defined(XGBOOST_USE_CUDA) +#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP) alpha_.SetDevice(ctx_->gpu_id); auto d_alpha = alpha_.ConstDeviceSpan(); auto d_labels = info.labels.View(ctx_->gpu_id); @@ -158,7 +158,7 @@ class QuantileRegression : public
ObjFunction { } #else common::AssertGPUSupport(); -#endif // defined(XGBOOST_USE_CUDA) +#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP) } // For multiple quantiles, we should extend the base score to a vector instead of @@ -215,8 +215,8 @@ XGBOOST_REGISTER_OBJECTIVE(QuantileRegression, QuantileRegression::Name()) .describe("Regression with quantile loss.") .set_body([]() { return new QuantileRegression(); }); -#if defined(XGBOOST_USE_CUDA) +#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP) DMLC_REGISTRY_FILE_TAG(quantile_obj_gpu); -#endif // defined(XGBOOST_USE_CUDA) +#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP) } // namespace obj } // namespace xgboost diff --git a/src/objective/regression_obj.cc b/src/objective/regression_obj.cc index 663989fbd..99bd200ab 100644 --- a/src/objective/regression_obj.cc +++ b/src/objective/regression_obj.cc @@ -13,6 +13,6 @@ DMLC_REGISTRY_FILE_TAG(regression_obj); } // namespace obj } // namespace xgboost -#ifndef XGBOOST_USE_CUDA +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) #include "regression_obj.cu" -#endif // XGBOOST_USE_CUDA +#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)