port elementwise_metric.cu

This commit is contained in:
amdsc21 2023-03-08 21:39:56 +01:00
parent 7e1b06417b
commit 4c4e5af29c
3 changed files with 28 additions and 0 deletions

View File

@ -5,4 +5,6 @@
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA)
#include "elementwise_metric.cu" #include "elementwise_metric.cu"
#elif !defined(XGBOOST_USE_HIP)
#include "elementwise_metric.hip"
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA)

View File

@ -29,6 +29,15 @@
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#endif // XGBOOST_USE_CUDA #endif // XGBOOST_USE_CUDA
#if defined(XGBOOST_USE_HIP)
#include <thrust/execution_policy.h> // thrust::hip::par
#include <thrust/functional.h> // thrust::plus<>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform_reduce.h>
#include "../common/device_helpers.hip.h"
#endif // XGBOOST_USE_HIP
namespace xgboost { namespace xgboost {
namespace metric { namespace metric {
// tag the this file, used by force static link later. // tag the this file, used by force static link later.
@ -84,6 +93,21 @@ PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
return PackedReduceResult{v, wt}; return PackedReduceResult{v, wt};
}, },
PackedReduceResult{}, thrust::plus<PackedReduceResult>()); PackedReduceResult{}, thrust::plus<PackedReduceResult>());
#elif defined(XGBOOST_USE_HIP)
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::counting_iterator<size_t> begin(0);
thrust::counting_iterator<size_t> end = begin + labels.Size();
result = thrust::transform_reduce(
thrust::hip::par(alloc), begin, end,
[=] XGBOOST_DEVICE(size_t i) {
auto idx = linalg::UnravelIndex(i, labels.Shape());
auto sample_id = std::get<0>(idx);
auto target_id = std::get<1>(idx);
auto res = loss(i, sample_id, target_id);
float v{std::get<0>(res)}, wt{std::get<1>(res)};
return PackedReduceResult{v, wt};
},
PackedReduceResult{}, thrust::plus<PackedReduceResult>());
#else #else
common::AssertGPUSupport(); common::AssertGPUSupport();
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)

View File

@ -0,0 +1,2 @@
#include "elementwise_metric.cu"