[jvm-packages] xgboost-spark warning when Spark encryption is turned on (#3667)

* added test, commented out right now

* reinstated test

* added fix for checking encryption settings

* fix by using RDD conf

* fix compilation

* renamed conf

* use SparkSession if available

* fix message

* nop

* code review fixes
This commit is contained in:
Joseph Bradley 2018-09-10 14:21:01 -07:00 committed by Nan Zhu
parent 3564b68b98
commit 14a8b96476
2 changed files with 64 additions and 0 deletions

View File

@ -31,6 +31,7 @@ import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
import org.apache.spark.sql.SparkSession
/**
@ -174,6 +175,38 @@ object XGBoost extends Serializable {
tracker
}
/**
 * Guards against training while Spark has been asked for SSL (`spark.ssl.enabled` is true).
 * In that situation an exception is thrown, unless conf `xgboost.spark.ignoreSsl` is also
 * true, in which case only a warning is logged and this method returns normally.
 *
 * Both confs are looked up on the active SparkSession when there is one; otherwise they are
 * looked up on the conf of the given SparkContext. A conf that is not set counts as false.
 *
 * @param sc SparkContext for the training dataset; consulted only when no SparkSession is
 *           active.
 */
private def validateSparkSslConf(sc: SparkContext): Unit = {
  // Pick the conf source once. Both flags are always read, the SSL flag first.
  val readFlag: String => Boolean = SparkSession.getActiveSession match {
    case Some(session) =>
      (key: String) => session.conf.getOption(key).getOrElse("false").toBoolean
    case None =>
      (key: String) => sc.getConf.getBoolean(key, false)
  }
  val sslRequested = readFlag("spark.ssl.enabled")
  val sslCheckOverridden = readFlag("xgboost.spark.ignoreSsl")

  if (sslRequested && !sslCheckOverridden) {
    // SSL was requested and the safety check has not been explicitly waived.
    throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
      "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
      "To override this protection and still use xgboost-spark at your own risk, " +
      "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
  } else if (sslRequested) {
    // SSL was requested but the user opted out of the check: proceed, but say so loudly.
    logger.warn("spark-xgboost is being run without encrypting data in transit! " +
      "Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
  }
}
/**
* @return A tuple of the booster and the metrics used to build training summary
*/
@ -187,6 +220,7 @@ object XGBoost extends Serializable {
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
validateSparkSslConf(trainingData.context)
if (params.contains("tree_method")) {
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
" for now")

View File

@ -47,4 +47,34 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
val eval = new EvalError()
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
// Verifies the SSL safety check: training must fail when spark.ssl.enabled is true, and must
// succeed again once xgboost.spark.ignoreSsl is also set to true.
test("Check for Spark encryption over-the-wire") {
  // Remember the pre-test value so it can be put back no matter how the test ends.
  val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
  try {
    ss.conf.set("spark.ssl.enabled", true)

    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
      "objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
    val training = buildDataFrame(Classification.train)

    withClue("xgboost-spark should throw an exception when spark.ssl.enabled = true but " +
      "xgboost.spark.ignoreSsl != true") {
      val thrown = intercept[Exception] {
        new XGBoostClassifier(paramMap).fit(training)
      }
      // The error message should name both confs involved.
      assert(thrown.getMessage.contains("xgboost.spark.ignoreSsl") &&
        thrown.getMessage.contains("spark.ssl.enabled"))
    }

    // Confirm that this check can be overridden.
    ss.conf.set("xgboost.spark.ignoreSsl", true)
    new XGBoostClassifier(paramMap).fit(training)
  } finally {
    // Undo the conf changes even when an assertion or fit above throws, so a failure here
    // does not leave the modified confs set on the session.
    originalSslConfOpt match {
      case None =>
        ss.conf.unset("spark.ssl.enabled")
      case Some(originalSslConf) =>
        ss.conf.set("spark.ssl.enabled", originalSslConf)
    }
    ss.conf.unset("xgboost.spark.ignoreSsl")
  }
}
}