finished survival_metric.cu

This commit is contained in:
amdsc21 2023-03-09 20:41:52 +01:00
parent b9d86d44d6
commit 4fd08b6c32
3 changed files with 38 additions and 6 deletions

View File

@ -6,6 +6,6 @@
*/ */
// Dummy file to keep the CUDA conditional compile trick. // Dummy file to keep the CUDA conditional compile trick.
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
#include "survival_metric.cu" #include "survival_metric.cu"
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)

View File

@ -24,6 +24,11 @@
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#endif // XGBOOST_USE_CUDA #endif // XGBOOST_USE_CUDA
#if defined(XGBOOST_USE_HIP)
#include <thrust/execution_policy.h> // thrust::hip::par
#include "../common/device_helpers.hip.h"
#endif // XGBOOST_USE_HIP
using AFTParam = xgboost::common::AFTParam; using AFTParam = xgboost::common::AFTParam;
using ProbabilityDistributionType = xgboost::common::ProbabilityDistributionType; using ProbabilityDistributionType = xgboost::common::ProbabilityDistributionType;
template <typename Distribution> template <typename Distribution>
@ -78,7 +83,7 @@ class ElementWiseSurvivalMetricsReduction {
return res; return res;
} }
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
PackedReduceResult DeviceReduceMetrics( PackedReduceResult DeviceReduceMetrics(
const HostDeviceVector<bst_float>& weights, const HostDeviceVector<bst_float>& weights,
@ -101,6 +106,8 @@ class ElementWiseSurvivalMetricsReduction {
auto d_policy = policy_; auto d_policy = policy_;
dh::XGBCachingDeviceAllocator<char> alloc; dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_CUDA)
PackedReduceResult result = thrust::transform_reduce( PackedReduceResult result = thrust::transform_reduce(
thrust::cuda::par(alloc), thrust::cuda::par(alloc),
begin, end, begin, end,
@ -115,11 +122,27 @@ class ElementWiseSurvivalMetricsReduction {
}, },
PackedReduceResult(), PackedReduceResult(),
thrust::plus<PackedReduceResult>()); thrust::plus<PackedReduceResult>());
#elif defined(XGBOOST_USE_HIP)
PackedReduceResult result = thrust::transform_reduce(
thrust::hip::par(alloc),
begin, end,
[=] XGBOOST_DEVICE(size_t idx) {
double weight = is_null_weight ? 1.0 : static_cast<double>(s_weights[idx]);
double residue = d_policy.EvalRow(
static_cast<double>(s_label_lower_bound[idx]),
static_cast<double>(s_label_upper_bound[idx]),
static_cast<double>(s_preds[idx]));
residue *= weight;
return PackedReduceResult{residue, weight};
},
PackedReduceResult(),
thrust::plus<PackedReduceResult>());
#endif
return result; return result;
} }
#endif // XGBOOST_USE_CUDA #endif // XGBOOST_USE_CUDA || defined(XGBOOST_USE_HIP)
PackedReduceResult Reduce( PackedReduceResult Reduce(
const Context &ctx, const Context &ctx,
@ -133,17 +156,22 @@ class ElementWiseSurvivalMetricsReduction {
result = CpuReduceMetrics(weights, labels_lower_bound, labels_upper_bound, result = CpuReduceMetrics(weights, labels_lower_bound, labels_upper_bound,
preds, ctx.Threads()); preds, ctx.Threads());
} }
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
else { // NOLINT else { // NOLINT
preds.SetDevice(ctx.gpu_id); preds.SetDevice(ctx.gpu_id);
labels_lower_bound.SetDevice(ctx.gpu_id); labels_lower_bound.SetDevice(ctx.gpu_id);
labels_upper_bound.SetDevice(ctx.gpu_id); labels_upper_bound.SetDevice(ctx.gpu_id);
weights.SetDevice(ctx.gpu_id); weights.SetDevice(ctx.gpu_id);
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(ctx.gpu_id)); dh::safe_cuda(cudaSetDevice(ctx.gpu_id));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(ctx.gpu_id));
#endif
result = DeviceReduceMetrics(weights, labels_lower_bound, labels_upper_bound, preds); result = DeviceReduceMetrics(weights, labels_lower_bound, labels_upper_bound, preds);
} }
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
return result; return result;
} }

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "survival_metric.cu"
#endif