[jvm-packages] xgboost-spark warning when Spark encryption is turned on (#3667)
* added test, commented out right now * reinstated test * added fix for checking encryption settings * fix by using RDD conf * fix compilation * renamed conf * use SparkSession if available * fix message * nop * code review fixes
This commit is contained in:
parent
3564b68b98
commit
14a8b96476
@ -31,6 +31,7 @@ import org.apache.commons.io.FileUtils
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
|
||||
|
||||
/**
|
||||
@ -174,6 +175,38 @@ object XGBoost extends Serializable {
|
||||
tracker
|
||||
}
|
||||
|
||||
/**
|
||||
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
|
||||
* If so, throw an exception unless this safety measure has been explicitly overridden
|
||||
* via conf `xgboost.spark.ignoreSsl`.
|
||||
*
|
||||
* @param sc SparkContext for the training dataset. When looking for the confs, this method
|
||||
* first checks for an active SparkSession. If one is not available, it falls back
|
||||
* to this SparkContext.
|
||||
*/
|
||||
private def validateSparkSslConf(sc: SparkContext): Unit = {
|
||||
val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
|
||||
SparkSession.getActiveSession match {
|
||||
case Some(ss) =>
|
||||
(ss.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean,
|
||||
ss.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean)
|
||||
case None =>
|
||||
(sc.getConf.getBoolean("spark.ssl.enabled", false),
|
||||
sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false))
|
||||
}
|
||||
if (sparkSslEnabled) {
|
||||
if (xgboostSparkIgnoreSsl) {
|
||||
logger.warn(s"spark-xgboost is being run without encrypting data in transit! " +
|
||||
s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
|
||||
} else {
|
||||
throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
|
||||
"in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
|
||||
"To override this protection and still use xgboost-spark at your own risk, " +
|
||||
"you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A tuple of the booster and the metrics used to build training summary
|
||||
*/
|
||||
@ -187,6 +220,7 @@ object XGBoost extends Serializable {
|
||||
eval: EvalTrait = null,
|
||||
useExternalMemory: Boolean = false,
|
||||
missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
|
||||
validateSparkSslConf(trainingData.context)
|
||||
if (params.contains("tree_method")) {
|
||||
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
|
||||
" for now")
|
||||
|
||||
@ -47,4 +47,34 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
|
||||
val eval = new EvalError()
|
||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
||||
}
|
||||
|
||||
test("Check for Spark encryption over-the-wire") {
|
||||
val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
|
||||
ss.conf.set("spark.ssl.enabled", true)
|
||||
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
|
||||
withClue("xgboost-spark should throw an exception when spark.ssl.enabled = true but " +
|
||||
"xgboost.spark.ignoreSsl != true") {
|
||||
val thrown = intercept[Exception] {
|
||||
new XGBoostClassifier(paramMap).fit(training)
|
||||
}
|
||||
assert(thrown.getMessage.contains("xgboost.spark.ignoreSsl") &&
|
||||
thrown.getMessage.contains("spark.ssl.enabled"))
|
||||
}
|
||||
|
||||
// Confirm that this check can be overridden.
|
||||
ss.conf.set("xgboost.spark.ignoreSsl", true)
|
||||
new XGBoostClassifier(paramMap).fit(training)
|
||||
|
||||
originalSslConfOpt match {
|
||||
case None =>
|
||||
ss.conf.unset("spark.ssl.enabled")
|
||||
case Some(originalSslConf) =>
|
||||
ss.conf.set("spark.ssl.enabled", originalSslConf)
|
||||
}
|
||||
ss.conf.unset("xgboost.spark.ignoreSsl")
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user