[jvm-packages] throw exception when tree_method=approx and device=cuda (#9478)
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
parent
05d7000096
commit
344f90b67b
@ -93,12 +93,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
|
||||
private val overridedParams = overrideParams(rawParams, sc)
|
||||
|
||||
validateSparkSslConf()
|
||||
|
||||
/**
|
||||
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
|
||||
* If so, throw an exception unless this safety measure has been explicitly overridden
|
||||
* via conf `xgboost.spark.ignoreSsl`.
|
||||
*/
|
||||
private def validateSparkSslConf: Unit = {
|
||||
private def validateSparkSslConf(): Unit = {
|
||||
val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
|
||||
SparkSession.getActiveSession match {
|
||||
case Some(ss) =>
|
||||
@ -148,55 +150,59 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
overridedParams
|
||||
}
|
||||
|
||||
/**
|
||||
* The Map parameters accepted by estimator's constructor may have string type,
|
||||
* Eg, Map("num_workers" -> "6", "num_round" -> 5), we need to convert these
|
||||
* kind of parameters into the correct type in the function.
|
||||
*
|
||||
* @return XGBoostExecutionParams
|
||||
*/
|
||||
def buildXGBRuntimeParams: XGBoostExecutionParams = {
|
||||
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
|
||||
val round = overridedParams("num_round").asInstanceOf[Int]
|
||||
val useExternalMemory = overridedParams
|
||||
.getOrElse("use_external_memory", false).asInstanceOf[Boolean]
|
||||
|
||||
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
|
||||
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
|
||||
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
|
||||
val allowNonZeroForMissing = overridedParams
|
||||
.getOrElse("allow_non_zero_for_missing", false)
|
||||
.asInstanceOf[Boolean]
|
||||
validateSparkSslConf
|
||||
var treeMethod: Option[String] = None
|
||||
if (overridedParams.contains("tree_method")) {
|
||||
require(overridedParams("tree_method") == "hist" ||
|
||||
overridedParams("tree_method") == "approx" ||
|
||||
overridedParams("tree_method") == "auto" ||
|
||||
overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" +
|
||||
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
|
||||
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
|
||||
}
|
||||
|
||||
// back-compatible with "gpu_hist"
|
||||
val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) {
|
||||
Some("cuda")
|
||||
} else overridedParams.get("device").map(_.toString)
|
||||
|
||||
if (overridedParams.contains("train_test_ratio")) {
|
||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
||||
"'eval_set_names'")
|
||||
}
|
||||
require(nWorkers > 0, "you must specify more than 0 workers")
|
||||
if (obj != null) {
|
||||
require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
|
||||
"is not defined, you have to specify the objective type as classification or regression" +
|
||||
" with a customized objective function")
|
||||
}
|
||||
|
||||
var trainTestRatio = 1.0
|
||||
if (overridedParams.contains("train_test_ratio")) {
|
||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
||||
"'eval_set_names'")
|
||||
trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double]
|
||||
}
|
||||
|
||||
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
|
||||
val round = overridedParams("num_round").asInstanceOf[Int]
|
||||
val useExternalMemory = overridedParams
|
||||
.getOrElse("use_external_memory", false).asInstanceOf[Boolean]
|
||||
|
||||
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
|
||||
val allowNonZeroForMissing = overridedParams
|
||||
.getOrElse("allow_non_zero_for_missing", false)
|
||||
.asInstanceOf[Boolean]
|
||||
|
||||
val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString)
|
||||
// back-compatible with "gpu_hist"
|
||||
val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) {
|
||||
Some("cuda")
|
||||
} else overridedParams.get("device").map(_.toString)
|
||||
|
||||
require(!(treeMethod.exists(_ == "approx") && device.exists(_ == "cuda")),
|
||||
"The tree method \"approx\" is not yet supported for Spark GPU cluster")
|
||||
|
||||
val trackerConf = overridedParams.get("tracker_conf") match {
|
||||
case None => TrackerConf()
|
||||
case Some(conf: TrackerConf) => conf
|
||||
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
|
||||
"instance of TrackerConf.")
|
||||
}
|
||||
val checkpointParam =
|
||||
ExternalCheckpointParams.extractParams(overridedParams)
|
||||
|
||||
val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
|
||||
.asInstanceOf[Double]
|
||||
val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams)
|
||||
|
||||
val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
|
||||
val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)
|
||||
|
||||
|
||||
@ -68,11 +68,13 @@ private[spark] trait LearningTaskParams extends Params {
|
||||
/**
|
||||
* Fraction of training points to use for testing.
|
||||
*/
|
||||
@Deprecated
|
||||
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
||||
"fraction of training points to use for testing",
|
||||
ParamValidators.inRange(0, 1))
|
||||
setDefault(trainTestRatio, 1.0)
|
||||
|
||||
@Deprecated
|
||||
final def getTrainTestRatio: Double = $(trainTestRatio)
|
||||
|
||||
/**
|
||||
|
||||
@ -92,4 +92,15 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
|
||||
classifier.getBaseScore
|
||||
}
|
||||
}
|
||||
|
||||
test("approx can't be used for gpu train") {
|
||||
val paramMap = Map("tree_method" -> "approx", "device" -> "cuda")
|
||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val thrown = intercept[IllegalArgumentException] {
|
||||
xgb.fit(trainingDF)
|
||||
}
|
||||
assert(thrown.getMessage.contains("The tree method \"approx\" is not yet supported " +
|
||||
"for Spark GPU cluster"))
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user