From 344f90b67ba0966c04ef05321eb99c127c3c2552 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 14 Aug 2023 17:52:14 +0800 Subject: [PATCH] [jvm-packages] throw exception when tree_method=approx and device=cuda (#9478) --------- Co-authored-by: Jiaming Yuan --- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 76 ++++++++++--------- .../spark/params/LearningTaskParams.scala | 2 + .../scala/spark/ParameterSuite.scala | 11 +++ 3 files changed, 54 insertions(+), 35 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 7bb245035..d12431479 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -93,12 +93,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s private val overridedParams = overrideParams(rawParams, sc) + validateSparkSslConf() + /** * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true). * If so, throw an exception unless this safety measure has been explicitly overridden * via conf `xgboost.spark.ignoreSsl`. */ - private def validateSparkSslConf: Unit = { + private def validateSparkSslConf(): Unit = { val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) = SparkSession.getActiveSession match { case Some(ss) => @@ -148,55 +150,59 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s overridedParams } + /** + * The Map parameters accepted by estimator's constructor may have string type, + * Eg, Map("num_workers" -> "6", "num_round" -> 5), we need to convert these + * kind of parameters into the correct type in the function. + * + * @return XGBoostExecutionParams + */ def buildXGBRuntimeParams: XGBoostExecutionParams = { - val nWorkers = overridedParams("num_workers").asInstanceOf[Int] - val round = overridedParams("num_round").asInstanceOf[Int] - val useExternalMemory = overridedParams - .getOrElse("use_external_memory", false).asInstanceOf[Boolean] + val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait] val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait] - val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float] - val allowNonZeroForMissing = overridedParams - .getOrElse("allow_non_zero_for_missing", false) - .asInstanceOf[Boolean] - validateSparkSslConf - var treeMethod: Option[String] = None - if (overridedParams.contains("tree_method")) { - require(overridedParams("tree_method") == "hist" || - overridedParams("tree_method") == "approx" || - overridedParams("tree_method") == "auto" || - overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" + - " as 'hist', 'approx', 'gpu_hist', and 'auto'") - treeMethod = Some(overridedParams("tree_method").asInstanceOf[String]) - } - - // back-compatible with "gpu_hist" - val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) { - Some("cuda") - } else overridedParams.get("device").map(_.toString) - - if (overridedParams.contains("train_test_ratio")) { - logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" + - " pass a training and multiple evaluation datasets by passing 'eval_sets' and " + - "'eval_set_names'") - } - require(nWorkers > 0, "you must specify more than 0 workers") if (obj != null) { require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " + "is not defined, you have to specify the objective type as classification or regression" + " with a customized objective function") } + + var trainTestRatio = 1.0 + if (overridedParams.contains("train_test_ratio")) { + logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" + + " pass a training and multiple evaluation datasets by passing 'eval_sets' and " + + "'eval_set_names'") + trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double] + } + + val nWorkers = overridedParams("num_workers").asInstanceOf[Int] + val round = overridedParams("num_round").asInstanceOf[Int] + val useExternalMemory = overridedParams + .getOrElse("use_external_memory", false).asInstanceOf[Boolean] + + val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float] + val allowNonZeroForMissing = overridedParams + .getOrElse("allow_non_zero_for_missing", false) + .asInstanceOf[Boolean] + + val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString) + // back-compatible with "gpu_hist" + val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) { + Some("cuda") + } else overridedParams.get("device").map(_.toString) + + require(!(treeMethod.exists(_ == "approx") && device.exists(_ == "cuda")), + "The tree method \"approx\" is not yet supported for Spark GPU cluster") + val trackerConf = overridedParams.get("tracker_conf") match { case None => TrackerConf() case Some(conf: TrackerConf) => conf case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " + "instance of TrackerConf.") } - val checkpointParam = - ExternalCheckpointParams.extractParams(overridedParams) - val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0) - .asInstanceOf[Double] + val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams) + val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long] val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index bcbd7548f..b73e6cbaa 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -68,11 +68,13 @@ private[spark] trait LearningTaskParams extends Params { /** * Fraction of training points to use for testing. */ + @Deprecated final val trainTestRatio = new DoubleParam(this, "trainTestRatio", "fraction of training points to use for testing", ParamValidators.inRange(0, 1)) setDefault(trainTestRatio, 1.0) + @Deprecated final def getTrainTestRatio: Double = $(trainTestRatio) /** diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala index 11b60e74d..f187f7394 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala @@ -92,4 +92,15 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll { classifier.getBaseScore } } + + test("approx can't be used for gpu train") { + val paramMap = Map("tree_method" -> "approx", "device" -> "cuda") + val trainingDF = buildDataFrame(MultiClassification.train) + val xgb = new XGBoostClassifier(paramMap) + val thrown = intercept[IllegalArgumentException] { + xgb.fit(trainingDF) + } + assert(thrown.getMessage.contains("The tree method \"approx\" is not yet supported " + + "for Spark GPU cluster")) + } }