add HIP flags, common

This commit is contained in:
amdsc21
2023-03-08 03:11:49 +01:00
parent 1e1c7fd8d5
commit 840f15209c
10 changed files with 44 additions and 33 deletions

View File

@@ -27,6 +27,12 @@
#define WITH_CUDA() true
#elif defined(__HIP_PLATFORM_AMD__)
#include <thrust/system/hip/error.h>
#include <thrust/system_error.h>
#define WITH_CUDA() true
#else
#define WITH_CUDA() false
@@ -34,7 +40,7 @@
#endif // defined(__CUDACC__)
namespace dh {
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
/*
* Error handling functions
*/
@@ -49,7 +55,7 @@ inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file,
}
return code;
}
#endif // defined(__CUDACC__)
#endif // defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
} // namespace dh
namespace xgboost {
@@ -167,7 +173,7 @@ class Range {
int AllVisibleGPUs();
inline void AssertGPUSupport() {
#ifndef XGBOOST_USE_CUDA
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
LOG(FATAL) << "XGBoost version not compiled with GPU support.";
#endif // XGBOOST_USE_CUDA
}
@@ -180,7 +186,7 @@ inline void AssertOneAPISupport() {
void SetDevice(std::int32_t device);
#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
inline void SetDevice(std::int32_t device) {
if (device >= 0) {
AssertGPUSupport();