From 6fa248b75fe2ca327aa612a60161e5e897132d0c Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 22:42:48 +0100 Subject: [PATCH] try elementwise_metric.cu --- src/common/common.h | 30 +++++++++++++++++------------- src/context.hip | 2 ++ src/gbm/gbtree.hip | 3 ++- src/metric/elementwise_metric.cc | 6 ++---- src/metric/elementwise_metric.hip | 2 ++ src/metric/multiclass_metric.cc | 2 +- src/metric/multiclass_metric.cu | 9 +++++++++ 7 files changed, 35 insertions(+), 19 deletions(-) diff --git a/src/common/common.h b/src/common/common.h index 867d08604..7ea15a54c 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -40,23 +40,12 @@ #endif // defined(__CUDACC__) namespace dh { -#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__) +#if defined(__CUDACC__) /* * Error handling functions */ #define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__) -#if defined(XGBOOST_USE_HIP) -inline hipError_t ThrowOnCudaError(hipError_t code, const char *file, int line) -{ - if (code != hipSuccess) { - LOG(FATAL) << thrust::system_error(code, thrust::hip_category(), - std::string{file} + ": " + // NOLINT - std::to_string(line)).what(); - } - return code; -} -#else inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file, int line) { if (code != cudaSuccess) { @@ -66,8 +55,23 @@ inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file, int line } return code; } + +#elif defined(__HIP_PLATFORM_AMD__) +/* + * Error handling functions + */ +#define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__) + +inline hipError_t ThrowOnCudaError(hipError_t code, const char *file, int line) +{ + if (code != hipSuccess) { + LOG(FATAL) << thrust::system_error(code, thrust::hip_category(), + std::string{file} + ": " + // NOLINT + std::to_string(line)).what(); + } + return code; +} #endif -#endif // defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__) } // namespace dh namespace xgboost { diff --git a/src/context.hip b/src/context.hip index 487feeccb..d4e3938bf 100644 --- a/src/context.hip +++ b/src/context.hip @@ -1,2 +1,4 @@ +#if defined(XGBOOST_USE_HIP) #include "context.cu" +#endif diff --git a/src/gbm/gbtree.hip b/src/gbm/gbtree.hip index 21d362ece..76040e75f 100644 --- a/src/gbm/gbtree.hip +++ b/src/gbm/gbtree.hip @@ -1,3 +1,4 @@ +#if defined(XGBOOST_USE_HIP) #include "gbtree.cu" - +#endif diff --git a/src/metric/elementwise_metric.cc b/src/metric/elementwise_metric.cc index 848c66747..414177ab1 100644 --- a/src/metric/elementwise_metric.cc +++ b/src/metric/elementwise_metric.cc @@ -3,8 +3,6 @@ */ // Dummy file to keep the CUDA conditional compile trick. -#if !defined(XGBOOST_USE_CUDA) +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) #include "elementwise_metric.cu" -#elif !defined(XGBOOST_USE_HIP) -#include "elementwise_metric.hip" -#endif // !defined(XGBOOST_USE_CUDA) +#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) diff --git a/src/metric/elementwise_metric.hip b/src/metric/elementwise_metric.hip index 72b4f3e6e..18e4916a4 100644 --- a/src/metric/elementwise_metric.hip +++ b/src/metric/elementwise_metric.hip @@ -1,2 +1,4 @@ +#if defined(XGBOOST_USE_HIP) #include "elementwise_metric.cu" +#endif diff --git a/src/metric/multiclass_metric.cc b/src/metric/multiclass_metric.cc index 7733a334f..1257fb0fa 100644 --- a/src/metric/multiclass_metric.cc +++ b/src/metric/multiclass_metric.cc @@ -3,6 +3,6 @@ */ // Dummy file to keep the CUDA conditional compile trick. -#if !defined(XGBOOST_USE_CUDA) +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) #include "multiclass_metric.cu" #endif // !defined(XGBOOST_USE_CUDA) diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index aed6e7f4b..4e7c87048 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -23,6 +23,15 @@ #include "../common/device_helpers.cuh" #endif // XGBOOST_USE_CUDA +#if defined(XGBOOST_USE_HIP) +#include // thrust::cuda::par +#include // thrust::plus<> +#include +#include + +#include "../common/device_helpers.hip.h" +#endif // XGBOOST_USE_HIP + namespace xgboost { namespace metric { // tag the this file, used by force static link later.