Define the new device parameter. (#9362)

This commit is contained in:
Jiaming Yuan
2023-07-13 19:30:25 +08:00
committed by GitHub
parent 2d0cd2817e
commit 04aff3af8e
63 changed files with 827 additions and 477 deletions

View File

@@ -280,7 +280,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
// - gpu id
// - predictor: Force to gpu predictor since native doesn't save predictor.
val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
booster.setParam("gpu_id", gpuId.toString)
booster.setParam("device", s"cuda:$gpuId")
logger.info("GPU transform on device: " + gpuId)
boosterFlag.isGpuParamsSet = true;
}