add HIP flags, common
This commit is contained in:
@@ -27,6 +27,12 @@
|
||||
|
||||
#define WITH_CUDA() true
|
||||
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
#include <thrust/system/hip/error.h>
|
||||
#include <thrust/system_error.h>
|
||||
|
||||
#define WITH_CUDA() true
|
||||
|
||||
#else
|
||||
|
||||
#define WITH_CUDA() false
|
||||
@@ -34,7 +40,7 @@
|
||||
#endif // defined(__CUDACC__)
|
||||
|
||||
namespace dh {
|
||||
#if defined(__CUDACC__)
|
||||
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
|
||||
/*
|
||||
* Error handling functions
|
||||
*/
|
||||
@@ -49,7 +55,7 @@ inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file,
|
||||
}
|
||||
return code;
|
||||
}
|
||||
#endif // defined(__CUDACC__)
|
||||
#endif // defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
|
||||
} // namespace dh
|
||||
|
||||
namespace xgboost {
|
||||
@@ -167,7 +173,7 @@ class Range {
|
||||
int AllVisibleGPUs();
|
||||
|
||||
inline void AssertGPUSupport() {
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
||||
LOG(FATAL) << "XGBoost version not compiled with GPU support.";
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
}
|
||||
@@ -180,7 +186,7 @@ inline void AssertOneAPISupport() {
|
||||
|
||||
void SetDevice(std::int32_t device);
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
||||
inline void SetDevice(std::int32_t device) {
|
||||
if (device >= 0) {
|
||||
AssertGPUSupport();
|
||||
|
||||
Reference in New Issue
Block a user