finish c_api.cu

This commit is contained in:
amdsc21 2023-03-10 05:12:51 +01:00
parent a76ccff390
commit bb6adda8a3

View File

@ -50,13 +50,21 @@ void XGBBuildInfoDevice(Json *p_info) {
void XGBoostAPIGuard::SetGPUAttribute() { void XGBoostAPIGuard::SetGPUAttribute() {
// Not calling `safe_cuda` to avoid unnecessary exception handling overhead. // Not calling `safe_cuda` to avoid unnecessary exception handling overhead.
// If errors, do nothing, assuming running on CPU only machine. // If errors, do nothing, assuming running on CPU only machine.
#if defined(XGBOOST_USE_CUDA)
cudaGetDevice(&device_id_); cudaGetDevice(&device_id_);
#elif defined(XGBOOST_USE_HIP)
hipGetDevice(&device_id_);
#endif
} }
void XGBoostAPIGuard::RestoreGPUAttribute() { void XGBoostAPIGuard::RestoreGPUAttribute() {
// Not calling `safe_cuda` to avoid unnecessary exception handling overhead. // Not calling `safe_cuda` to avoid unnecessary exception handling overhead.
// If errors, do nothing, assuming running on CPU only machine. // If errors, do nothing, assuming running on CPU only machine.
#if defined(XGBOOST_USE_CUDA)
cudaSetDevice(device_id_); cudaSetDevice(device_id_);
#elif defined(XGBOOST_USE_HIP)
hipSetDevice(device_id_);
#endif
} }
} // namespace xgboost } // namespace xgboost