[jvm-packages] Add the new device parameter. (#9385)

This commit is contained in:
Jiaming Yuan
2023-07-17 18:40:39 +08:00
committed by GitHub
parent 2caceb157d
commit f4fb2be101
15 changed files with 112 additions and 47 deletions

View File

@@ -137,8 +137,12 @@ object GpuPreXGBoost extends PreXGBoostProvider {
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
estimator match {
case est: XGBoostEstimatorCommon =>
require(est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
s"GPU train requires tree_method set to gpu_hist")
require(
est.isDefined(est.device) &&
(est.getDevice.equals("cuda") || est.getDevice.equals("gpu")) ||
est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
s"GPU train requires `device` set to `cuda` or `gpu`."
)
val groupName = estimator match {
case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) {
regressor.getGroupCol } else ""