finish common.cu

This commit is contained in:
amdsc21 2023-03-10 05:19:41 +01:00
parent 8fd2af1c8b
commit 91a5ef762e
3 changed files with 17 additions and 1 deletions

View File

@ -8,7 +8,11 @@ namespace common {
void SetDevice(std::int32_t device) {
if (device >= 0) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device));
#endif
}
}
@ -17,9 +21,17 @@ int AllVisibleGPUs() {
try {
// When compiled with CUDA but running on CPU only device,
// cudaGetDeviceCount will fail.
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipGetDeviceCount(&n_visgpus));
#endif
} catch (const dmlc::Error &) {
#if defined(XGBOOST_USE_CUDA)
cudaGetLastError(); // reset error.
#elif defined(XGBOOST_USE_HIP)
hipGetLastError(); // reset error.
#endif
return 0;
}
return n_visgpus;

View File

@ -156,7 +156,7 @@ int AllVisibleGPUs();
inline void AssertGPUSupport() {
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
LOG(FATAL) << "XGBoost version not compiled with GPU support.";
#endif // XGBOOST_USE_CUDA
#endif // XGBOOST_USE_CUDA && XGBOOST_USE_HIP
}
inline void AssertOneAPISupport() {

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "common.cu"
#endif