enable rocm, fix common.h

This commit is contained in:
amdsc21 2023-03-08 06:45:03 +01:00
parent ca8f4e7993
commit 312e58ec99

View File

@ -46,8 +46,19 @@ namespace dh {
*/ */
#define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__) #define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__)
inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file, #if defined(XGBOOST_USE_HIP)
int line) { inline hipError_t ThrowOnCudaError(hipError_t code, const char *file, int line)
{
if (code != hipSuccess) {
LOG(FATAL) << thrust::system_error(code, thrust::hip_category(),
std::string{file} + ": " + // NOLINT
std::to_string(line)).what();
}
return code;
}
#else
inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file, int line)
{
if (code != cudaSuccess) { if (code != cudaSuccess) {
LOG(FATAL) << thrust::system_error(code, thrust::cuda_category(), LOG(FATAL) << thrust::system_error(code, thrust::cuda_category(),
std::string{file} + ": " + // NOLINT std::string{file} + ": " + // NOLINT
@ -55,6 +66,7 @@ inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file,
} }
return code; return code;
} }
#endif
#endif // defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__) #endif // defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
} // namespace dh } // namespace dh