From 91a5ef762e2df8a231f51209b851f9a8d0a15c14 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Fri, 10 Mar 2023 05:19:41 +0100 Subject: [PATCH] finish common.cu --- src/common/common.cu | 12 ++++++++++++ src/common/common.h | 2 +- src/common/common.hip | 4 ++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/common/common.cu b/src/common/common.cu index b6965904a..0997b7c83 100644 --- a/src/common/common.cu +++ b/src/common/common.cu @@ -8,7 +8,11 @@ namespace common { void SetDevice(std::int32_t device) { if (device >= 0) { +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaSetDevice(device)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device)); +#endif } } @@ -17,9 +21,17 @@ int AllVisibleGPUs() { try { // When compiled with CUDA but running on CPU only device, // cudaGetDeviceCount will fail. +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaGetDeviceCount(&n_visgpus)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipGetDeviceCount(&n_visgpus)); +#endif } catch (const dmlc::Error &) { +#if defined(XGBOOST_USE_CUDA) cudaGetLastError(); // reset error. +#elif defined(XGBOOST_USE_HIP) + hipGetLastError(); // reset error. +#endif return 0; } return n_visgpus; diff --git a/src/common/common.h b/src/common/common.h index 9d1f1e48a..04482a107 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -156,7 +156,7 @@ int AllVisibleGPUs(); inline void AssertGPUSupport() { #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) LOG(FATAL) << "XGBoost version not compiled with GPU support."; -#endif // XGBOOST_USE_CUDA +#endif // XGBOOST_USE_CUDA && XGBOOST_USE_HIP } inline void AssertOneAPISupport() { diff --git a/src/common/common.hip b/src/common/common.hip index e69de29bb..c665b11bc 100644 --- a/src/common/common.hip +++ b/src/common/common.hip @@ -0,0 +1,4 @@ + +#if defined(XGBOOST_USE_HIP) +#include "common.cu" +#endif