Define the new device parameter. (#9362)
This commit is contained in:
@@ -280,7 +280,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
// - gpu id
|
||||
// - predictor: Force to gpu predictor since native doesn't save predictor.
|
||||
val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
|
||||
booster.setParam("gpu_id", gpuId.toString)
|
||||
booster.setParam("device", s"cuda:$gpuId")
|
||||
logger.info("GPU transform on device: " + gpuId)
|
||||
boosterFlag.isGpuParamsSet = true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user