finish common.cu
This commit is contained in:
parent
8fd2af1c8b
commit
91a5ef762e
@ -8,7 +8,11 @@ namespace common {
|
|||||||
|
|
||||||
void SetDevice(std::int32_t device) {
|
void SetDevice(std::int32_t device) {
|
||||||
if (device >= 0) {
|
if (device >= 0) {
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
dh::safe_cuda(cudaSetDevice(device));
|
dh::safe_cuda(cudaSetDevice(device));
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -17,9 +21,17 @@ int AllVisibleGPUs() {
|
|||||||
try {
|
try {
|
||||||
// When compiled with CUDA but running on CPU only device,
|
// When compiled with CUDA but running on CPU only device,
|
||||||
// cudaGetDeviceCount will fail.
|
// cudaGetDeviceCount will fail.
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
|
dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipGetDeviceCount(&n_visgpus));
|
||||||
|
#endif
|
||||||
} catch (const dmlc::Error &) {
|
} catch (const dmlc::Error &) {
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
cudaGetLastError(); // reset error.
|
cudaGetLastError(); // reset error.
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
hipGetLastError(); // reset error.
|
||||||
|
#endif
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
return n_visgpus;
|
return n_visgpus;
|
||||||
|
|||||||
@ -156,7 +156,7 @@ int AllVisibleGPUs();
|
|||||||
inline void AssertGPUSupport() {
|
inline void AssertGPUSupport() {
|
||||||
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
||||||
LOG(FATAL) << "XGBoost version not compiled with GPU support.";
|
LOG(FATAL) << "XGBoost version not compiled with GPU support.";
|
||||||
#endif // XGBOOST_USE_CUDA
|
#endif // XGBOOST_USE_CUDA && XGBOOST_USE_HIP
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void AssertOneAPISupport() {
|
inline void AssertOneAPISupport() {
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
#include "common.cu"
|
||||||
|
#endif
|
||||||
Loading…
x
Reference in New Issue
Block a user