[jvm-packages] throw exception when tree_method=approx and device=cuda (#9478)
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
parent
05d7000096
commit
344f90b67b
@ -93,12 +93,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
|
|
||||||
private val overridedParams = overrideParams(rawParams, sc)
|
private val overridedParams = overrideParams(rawParams, sc)
|
||||||
|
|
||||||
|
validateSparkSslConf()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
|
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
|
||||||
* If so, throw an exception unless this safety measure has been explicitly overridden
|
* If so, throw an exception unless this safety measure has been explicitly overridden
|
||||||
* via conf `xgboost.spark.ignoreSsl`.
|
* via conf `xgboost.spark.ignoreSsl`.
|
||||||
*/
|
*/
|
||||||
private def validateSparkSslConf: Unit = {
|
private def validateSparkSslConf(): Unit = {
|
||||||
val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
|
val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
|
||||||
SparkSession.getActiveSession match {
|
SparkSession.getActiveSession match {
|
||||||
case Some(ss) =>
|
case Some(ss) =>
|
||||||
@ -148,55 +150,59 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
overridedParams
|
overridedParams
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The Map parameters accepted by estimator's constructor may have string type,
|
||||||
|
* Eg, Map("num_workers" -> "6", "num_round" -> 5), we need to convert these
|
||||||
|
* kind of parameters into the correct type in the function.
|
||||||
|
*
|
||||||
|
* @return XGBoostExecutionParams
|
||||||
|
*/
|
||||||
def buildXGBRuntimeParams: XGBoostExecutionParams = {
|
def buildXGBRuntimeParams: XGBoostExecutionParams = {
|
||||||
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
|
|
||||||
val round = overridedParams("num_round").asInstanceOf[Int]
|
|
||||||
val useExternalMemory = overridedParams
|
|
||||||
.getOrElse("use_external_memory", false).asInstanceOf[Boolean]
|
|
||||||
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
|
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
|
||||||
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
|
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
|
||||||
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
|
|
||||||
val allowNonZeroForMissing = overridedParams
|
|
||||||
.getOrElse("allow_non_zero_for_missing", false)
|
|
||||||
.asInstanceOf[Boolean]
|
|
||||||
validateSparkSslConf
|
|
||||||
var treeMethod: Option[String] = None
|
|
||||||
if (overridedParams.contains("tree_method")) {
|
|
||||||
require(overridedParams("tree_method") == "hist" ||
|
|
||||||
overridedParams("tree_method") == "approx" ||
|
|
||||||
overridedParams("tree_method") == "auto" ||
|
|
||||||
overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" +
|
|
||||||
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
|
|
||||||
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
|
|
||||||
}
|
|
||||||
|
|
||||||
// back-compatible with "gpu_hist"
|
|
||||||
val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) {
|
|
||||||
Some("cuda")
|
|
||||||
} else overridedParams.get("device").map(_.toString)
|
|
||||||
|
|
||||||
if (overridedParams.contains("train_test_ratio")) {
|
|
||||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
|
||||||
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
|
||||||
"'eval_set_names'")
|
|
||||||
}
|
|
||||||
require(nWorkers > 0, "you must specify more than 0 workers")
|
|
||||||
if (obj != null) {
|
if (obj != null) {
|
||||||
require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
|
require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
|
||||||
"is not defined, you have to specify the objective type as classification or regression" +
|
"is not defined, you have to specify the objective type as classification or regression" +
|
||||||
" with a customized objective function")
|
" with a customized objective function")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var trainTestRatio = 1.0
|
||||||
|
if (overridedParams.contains("train_test_ratio")) {
|
||||||
|
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||||
|
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
||||||
|
"'eval_set_names'")
|
||||||
|
trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double]
|
||||||
|
}
|
||||||
|
|
||||||
|
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
|
||||||
|
val round = overridedParams("num_round").asInstanceOf[Int]
|
||||||
|
val useExternalMemory = overridedParams
|
||||||
|
.getOrElse("use_external_memory", false).asInstanceOf[Boolean]
|
||||||
|
|
||||||
|
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
|
||||||
|
val allowNonZeroForMissing = overridedParams
|
||||||
|
.getOrElse("allow_non_zero_for_missing", false)
|
||||||
|
.asInstanceOf[Boolean]
|
||||||
|
|
||||||
|
val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString)
|
||||||
|
// back-compatible with "gpu_hist"
|
||||||
|
val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) {
|
||||||
|
Some("cuda")
|
||||||
|
} else overridedParams.get("device").map(_.toString)
|
||||||
|
|
||||||
|
require(!(treeMethod.exists(_ == "approx") && device.exists(_ == "cuda")),
|
||||||
|
"The tree method \"approx\" is not yet supported for Spark GPU cluster")
|
||||||
|
|
||||||
val trackerConf = overridedParams.get("tracker_conf") match {
|
val trackerConf = overridedParams.get("tracker_conf") match {
|
||||||
case None => TrackerConf()
|
case None => TrackerConf()
|
||||||
case Some(conf: TrackerConf) => conf
|
case Some(conf: TrackerConf) => conf
|
||||||
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
|
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
|
||||||
"instance of TrackerConf.")
|
"instance of TrackerConf.")
|
||||||
}
|
}
|
||||||
val checkpointParam =
|
|
||||||
ExternalCheckpointParams.extractParams(overridedParams)
|
|
||||||
|
|
||||||
val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
|
val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams)
|
||||||
.asInstanceOf[Double]
|
|
||||||
val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
|
val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
|
||||||
val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)
|
val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)
|
||||||
|
|
||||||
|
|||||||
@ -68,11 +68,13 @@ private[spark] trait LearningTaskParams extends Params {
|
|||||||
/**
|
/**
|
||||||
* Fraction of training points to use for testing.
|
* Fraction of training points to use for testing.
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
||||||
"fraction of training points to use for testing",
|
"fraction of training points to use for testing",
|
||||||
ParamValidators.inRange(0, 1))
|
ParamValidators.inRange(0, 1))
|
||||||
setDefault(trainTestRatio, 1.0)
|
setDefault(trainTestRatio, 1.0)
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
final def getTrainTestRatio: Double = $(trainTestRatio)
|
final def getTrainTestRatio: Double = $(trainTestRatio)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -92,4 +92,15 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
|
|||||||
classifier.getBaseScore
|
classifier.getBaseScore
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("approx can't be used for gpu train") {
|
||||||
|
val paramMap = Map("tree_method" -> "approx", "device" -> "cuda")
|
||||||
|
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||||
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
|
val thrown = intercept[IllegalArgumentException] {
|
||||||
|
xgb.fit(trainingDF)
|
||||||
|
}
|
||||||
|
assert(thrown.getMessage.contains("The tree method \"approx\" is not yet supported " +
|
||||||
|
"for Spark GPU cluster"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user