Define the new device parameter. (#9362)

This commit is contained in:
Jiaming Yuan
2023-07-13 19:30:25 +08:00
committed by GitHub
parent 2d0cd2817e
commit 04aff3af8e
63 changed files with 827 additions and 477 deletions

View File

@@ -326,7 +326,7 @@ object XGBoost extends Serializable {
getGPUAddrFromResources
}
logger.info("Leveraging gpu device " + gpuId + " to train")
params = params + ("gpu_id" -> gpuId)
params = params + ("device" -> s"cuda:$gpuId")
}
val booster = if (makeCheckpoint) {
SXGBoost.trainAndSaveCheckpoint(