finish common.cu
This commit is contained in:
parent
8fd2af1c8b
commit
91a5ef762e
@ -8,7 +8,11 @@ namespace common {
|
||||
|
||||
void SetDevice(std::int32_t device) {
|
||||
if (device >= 0) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,9 +21,17 @@ int AllVisibleGPUs() {
|
||||
try {
|
||||
// When compiled with CUDA but running on CPU only device,
|
||||
// cudaGetDeviceCount will fail.
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipGetDeviceCount(&n_visgpus));
|
||||
#endif
|
||||
} catch (const dmlc::Error &) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
cudaGetLastError(); // reset error.
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
hipGetLastError(); // reset error.
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
return n_visgpus;
|
||||
|
||||
@ -156,7 +156,7 @@ int AllVisibleGPUs();
|
||||
inline void AssertGPUSupport() {
|
||||
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
||||
LOG(FATAL) << "XGBoost version not compiled with GPU support.";
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
#endif // XGBOOST_USE_CUDA && XGBOOST_USE_HIP
|
||||
}
|
||||
|
||||
inline void AssertOneAPISupport() {
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
#include "common.cu"
|
||||
#endif
|
||||
Loading…
x
Reference in New Issue
Block a user