add cuda to hip wrapper

This commit is contained in:
Your Name
2023-10-17 12:42:37 -07:00
parent ea19555474
commit ffbbc9c968
35 changed files with 60 additions and 509 deletions

View File

@@ -159,11 +159,7 @@ class ElementWiseSurvivalMetricsReduction {
labels_upper_bound.SetDevice(ctx.gpu_id);
weights.SetDevice(ctx.gpu_id);
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(ctx.gpu_id));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(ctx.gpu_id));
#endif
result = DeviceReduceMetrics(weights, labels_lower_bound, labels_upper_bound, preds);
}