From 312e58ec998a01dba41702458801b7421c2eed9c Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 06:45:03 +0100 Subject: [PATCH] enable rocm, fix common.h --- src/common/common.h | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/common/common.h b/src/common/common.h index 6ea342232..867d08604 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -46,8 +46,19 @@ namespace dh { */ #define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__) -inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file, - int line) { +#if defined(XGBOOST_USE_HIP) +inline hipError_t ThrowOnCudaError(hipError_t code, const char *file, int line) +{ + if (code != hipSuccess) { + LOG(FATAL) << thrust::system_error(code, thrust::hip_category(), + std::string{file} + ": " + // NOLINT + std::to_string(line)).what(); + } + return code; +} +#else +inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file, int line) +{ if (code != cudaSuccess) { LOG(FATAL) << thrust::system_error(code, thrust::cuda_category(), std::string{file} + ": " + // NOLINT @@ -55,6 +66,7 @@ inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file, } return code; } +#endif #endif // defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__) } // namespace dh