[jvm-packages] set device to cuda when tree method is "gpu_hist" (#9412)

This commit is contained in:
Bobby Wang 2023-07-24 18:32:25 +08:00 committed by GitHub
parent a196443a07
commit 1b657a5513
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -180,10 +180,12 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
}
val device: Option[String] = overridedParams.get("device") match {
case None => None
case Some(dev: String) => if (treeMethod == "gpu_hist") Some("cuda") else Some(dev)
}
// back-compatible with "gpu_hist"
val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) {
Some("cuda")
} else overridedParams.get("device").map(_.toString)
if (overridedParams.contains("train_test_ratio")) {
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +