[jvm-packages] Add the new device parameter. (#9385)
This commit is contained in:
@@ -137,8 +137,12 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
|
||||
estimator match {
|
||||
case est: XGBoostEstimatorCommon =>
|
||||
require(est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
|
||||
s"GPU train requires tree_method set to gpu_hist")
|
||||
require(
|
||||
est.isDefined(est.device) &&
|
||||
(est.getDevice.equals("cuda") || est.getDevice.equals("gpu")) ||
|
||||
est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
|
||||
s"GPU train requires `device` set to `cuda` or `gpu`."
|
||||
)
|
||||
val groupName = estimator match {
|
||||
case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) {
|
||||
regressor.getGroupCol } else ""
|
||||
|
||||
Reference in New Issue
Block a user