finish c_api.cu

This commit is contained in:
amdsc21 2023-03-10 05:12:51 +01:00
parent a76ccff390
commit bb6adda8a3

View File

@ -50,13 +50,21 @@ void XGBBuildInfoDevice(Json *p_info) {
void XGBoostAPIGuard::SetGPUAttribute() {
// Not calling `safe_cuda` to avoid unnecessary exception handling overhead.
// If errors, do nothing, assuming running on CPU only machine.
#if defined(XGBOOST_USE_CUDA)
cudaGetDevice(&device_id_);
#elif defined(XGBOOST_USE_HIP)
hipGetDevice(&device_id_);
#endif
}
void XGBoostAPIGuard::RestoreGPUAttribute() {
// Not calling `safe_cuda` to avoid unnecessary exception handling overhead.
// If errors, do nothing, assuming running on CPU only machine.
#if defined(XGBOOST_USE_CUDA)
cudaSetDevice(device_id_);
#elif defined(XGBOOST_USE_HIP)
hipSetDevice(device_id_);
#endif
}
} // namespace xgboost