From bb6adda8a3ce7e150f7f282587fa5fce87f1bbf8 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Fri, 10 Mar 2023 05:12:51 +0100 Subject: [PATCH] finish c_api.cu --- src/c_api/c_api.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index 61e6ca44e..89830b89b 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -50,13 +50,21 @@ void XGBBuildInfoDevice(Json *p_info) { void XGBoostAPIGuard::SetGPUAttribute() { // Not calling `safe_cuda` to avoid unnecessary exception handling overhead. // If errors, do nothing, assuming running on CPU only machine. +#if defined(XGBOOST_USE_CUDA) cudaGetDevice(&device_id_); +#elif defined(XGBOOST_USE_HIP) + hipGetDevice(&device_id_); +#endif } void XGBoostAPIGuard::RestoreGPUAttribute() { // Not calling `safe_cuda` to avoid unnecessary exception handling overhead. // If errors, do nothing, assuming running on CPU only machine. +#if defined(XGBOOST_USE_CUDA) cudaSetDevice(device_id_); +#elif defined(XGBOOST_USE_HIP) + hipSetDevice(device_id_); +#endif } } // namespace xgboost