From 4c4e5af29cc7a7fba92948a450a495b0435781fd Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 21:39:56 +0100 Subject: [PATCH] port elementwise_metric.cu --- src/metric/elementwise_metric.cc | 2 ++ src/metric/elementwise_metric.cu | 24 ++++++++++++++++++++++++ src/metric/elementwise_metric.hip | 2 ++ 3 files changed, 28 insertions(+) diff --git a/src/metric/elementwise_metric.cc b/src/metric/elementwise_metric.cc index 0a3e673c1..848c66747 100644 --- a/src/metric/elementwise_metric.cc +++ b/src/metric/elementwise_metric.cc @@ -5,4 +5,6 @@ #if !defined(XGBOOST_USE_CUDA) #include "elementwise_metric.cu" +#elif !defined(XGBOOST_USE_HIP) +#include "elementwise_metric.hip" #endif // !defined(XGBOOST_USE_CUDA) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 9006bdfca..aab1e7a95 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -29,6 +29,15 @@ #include "../common/device_helpers.cuh" #endif // XGBOOST_USE_CUDA +#if defined(XGBOOST_USE_HIP) +#include // thrust::hip::par +#include // thrust::plus<> +#include +#include + +#include "../common/device_helpers.hip.h" +#endif // XGBOOST_USE_HIP + namespace xgboost { namespace metric { // tag the this file, used by force static link later. @@ -84,6 +93,21 @@ PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) { return PackedReduceResult{v, wt}; }, PackedReduceResult{}, thrust::plus()); +#elif defined(XGBOOST_USE_HIP) + dh::XGBCachingDeviceAllocator alloc; + thrust::counting_iterator begin(0); + thrust::counting_iterator end = begin + labels.Size(); + result = thrust::transform_reduce( + thrust::hip::par(alloc), begin, end, + [=] XGBOOST_DEVICE(size_t i) { + auto idx = linalg::UnravelIndex(i, labels.Shape()); + auto sample_id = std::get<0>(idx); + auto target_id = std::get<1>(idx); + auto res = loss(i, sample_id, target_id); + float v{std::get<0>(res)}, wt{std::get<1>(res)}; + return PackedReduceResult{v, wt}; + }, + PackedReduceResult{}, thrust::plus()); #else common::AssertGPUSupport(); #endif // defined(XGBOOST_USE_CUDA) diff --git a/src/metric/elementwise_metric.hip b/src/metric/elementwise_metric.hip index e69de29bb..72b4f3e6e 100644 --- a/src/metric/elementwise_metric.hip +++ b/src/metric/elementwise_metric.hip @@ -0,0 +1,2 @@ + +#include "elementwise_metric.cu"