finish c_api.cu
This commit is contained in:
parent
a76ccff390
commit
bb6adda8a3
@ -50,13 +50,21 @@ void XGBBuildInfoDevice(Json *p_info) {
|
|||||||
void XGBoostAPIGuard::SetGPUAttribute() {
|
void XGBoostAPIGuard::SetGPUAttribute() {
|
||||||
// Not calling `safe_cuda` to avoid unnecessary exception handling overhead.
|
// Not calling `safe_cuda` to avoid unnecessary exception handling overhead.
|
||||||
// If errors, do nothing, assuming running on CPU only machine.
|
// If errors, do nothing, assuming running on CPU only machine.
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
cudaGetDevice(&device_id_);
|
cudaGetDevice(&device_id_);
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
hipGetDevice(&device_id_);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void XGBoostAPIGuard::RestoreGPUAttribute() {
|
void XGBoostAPIGuard::RestoreGPUAttribute() {
|
||||||
// Not calling `safe_cuda` to avoid unnecessary exception handling overhead.
|
// Not calling `safe_cuda` to avoid unnecessary exception handling overhead.
|
||||||
// If errors, do nothing, assuming running on CPU only machine.
|
// If errors, do nothing, assuming running on CPU only machine.
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
cudaSetDevice(device_id_);
|
cudaSetDevice(device_id_);
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
hipSetDevice(device_id_);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user