[jvm-packages] set device to cuda when tree method is "gpu_hist" (#9412)
This commit is contained in:
parent
a196443a07
commit
1b657a5513
@ -180,10 +180,12 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
|
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
|
||||||
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
|
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
|
||||||
}
|
}
|
||||||
val device: Option[String] = overridedParams.get("device") match {
|
|
||||||
case None => None
|
// back-compatible with "gpu_hist"
|
||||||
case Some(dev: String) => if (treeMethod == "gpu_hist") Some("cuda") else Some(dev)
|
val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) {
|
||||||
}
|
Some("cuda")
|
||||||
|
} else overridedParams.get("device").map(_.toString)
|
||||||
|
|
||||||
if (overridedParams.contains("train_test_ratio")) {
|
if (overridedParams.contains("train_test_ratio")) {
|
||||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||||
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user