diff --git a/src/metric/survival_metric.cc b/src/metric/survival_metric.cc
index cf21a7fa2..34f0b461e 100644
--- a/src/metric/survival_metric.cc
+++ b/src/metric/survival_metric.cc
@@ -6,6 +6,6 @@
  */
 
 // Dummy file to keep the CUDA conditional compile trick.
-#if !defined(XGBOOST_USE_CUDA)
+#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
 #include "survival_metric.cu"
-#endif  // !defined(XGBOOST_USE_CUDA)
+#endif  // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
diff --git a/src/metric/survival_metric.cu b/src/metric/survival_metric.cu
index 8205f07a1..6f17c6006 100644
--- a/src/metric/survival_metric.cu
+++ b/src/metric/survival_metric.cu
@@ -24,6 +24,11 @@
 #include "../common/device_helpers.cuh"
 #endif  // XGBOOST_USE_CUDA
 
+#if defined(XGBOOST_USE_HIP)
+#include <thrust/execution_policy.h>  // thrust::hip::par
+#include "../common/device_helpers.hip.h"
+#endif  // XGBOOST_USE_HIP
+
 using AFTParam = xgboost::common::AFTParam;
 using ProbabilityDistributionType = xgboost::common::ProbabilityDistributionType;
 template <typename Distribution>
@@ -78,7 +83,7 @@
     return res;
   }
 
-#if defined(XGBOOST_USE_CUDA)
+#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
 
   PackedReduceResult DeviceReduceMetrics(
       const HostDeviceVector<bst_float>& weights,
@@ -101,6 +106,8 @@
 
     auto d_policy = policy_;
     dh::XGBCachingDeviceAllocator<char> alloc;
+
+#if defined(XGBOOST_USE_CUDA)
     PackedReduceResult result = thrust::transform_reduce(
         thrust::cuda::par(alloc),
         begin, end,
@@ -115,11 +122,27 @@
         },
         PackedReduceResult(),
         thrust::plus<PackedReduceResult>());
+#elif defined(XGBOOST_USE_HIP)
+    PackedReduceResult result = thrust::transform_reduce(
+        thrust::hip::par(alloc),
+        begin, end,
+        [=] XGBOOST_DEVICE(size_t idx) {
+          double weight = is_null_weight ? 1.0 : static_cast<double>(s_weights[idx]);
+          double residue = d_policy.EvalRow(
+              static_cast<double>(s_label_lower_bound[idx]),
+              static_cast<double>(s_label_upper_bound[idx]),
+              static_cast<double>(s_preds[idx]));
+          residue *= weight;
+          return PackedReduceResult{residue, weight};
+        },
+        PackedReduceResult(),
+        thrust::plus<PackedReduceResult>());
+#endif
 
     return result;
   }
 
-#endif  // XGBOOST_USE_CUDA
+#endif  // XGBOOST_USE_CUDA || defined(XGBOOST_USE_HIP)
 
   PackedReduceResult Reduce(
       const Context &ctx,
@@ -133,17 +156,22 @@
       result = CpuReduceMetrics(weights, labels_lower_bound,
                                 labels_upper_bound, preds,
                                 ctx.Threads());
     }
-#if defined(XGBOOST_USE_CUDA)
+#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
     else {  // NOLINT
       preds.SetDevice(ctx.gpu_id);
       labels_lower_bound.SetDevice(ctx.gpu_id);
       labels_upper_bound.SetDevice(ctx.gpu_id);
       weights.SetDevice(ctx.gpu_id);
+#if defined(XGBOOST_USE_CUDA)
       dh::safe_cuda(cudaSetDevice(ctx.gpu_id));
+#elif defined(XGBOOST_USE_HIP)
+      dh::safe_cuda(hipSetDevice(ctx.gpu_id));
+#endif
+
       result = DeviceReduceMetrics(weights, labels_lower_bound,
                                    labels_upper_bound, preds);
     }
-#endif  // defined(XGBOOST_USE_CUDA)
+#endif  // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
     return result;
   }
diff --git a/src/metric/survival_metric.hip b/src/metric/survival_metric.hip
index e69de29bb..84a7d1ec2 100644
--- a/src/metric/survival_metric.hip
+++ b/src/metric/survival_metric.hip
@@ -0,0 +1,4 @@
+
+#if defined(XGBOOST_USE_HIP)
+#include "survival_metric.cu"
+#endif