Define the new device parameter. (#9362)
This commit is contained in:
@@ -326,7 +326,7 @@ object XGBoost extends Serializable {
|
||||
getGPUAddrFromResources
|
||||
}
|
||||
logger.info("Leveraging gpu device " + gpuId + " to train")
|
||||
params = params + ("gpu_id" -> gpuId)
|
||||
params = params + ("device" -> s"cuda:$gpuId")
|
||||
}
|
||||
val booster = if (makeCheckpoint) {
|
||||
SXGBoost.trainAndSaveCheckpoint(
|
||||
|
||||
Reference in New Issue
Block a user