[jvm-packages] xgboost-spark warning when Spark encryption is turned on (#3667)
* added test, commented out right now * reinstated test * added fix for checking encryption settings * fix by using RDD conf * fix compilation * renamed conf * use SparkSession if available * fix message * nop * code review fixes
This commit is contained in:
parent
3564b68b98
commit
14a8b96476
@ -31,6 +31,7 @@ import org.apache.commons.io.FileUtils
|
|||||||
import org.apache.commons.logging.LogFactory
|
import org.apache.commons.logging.LogFactory
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
|
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
|
||||||
|
import org.apache.spark.sql.SparkSession
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -174,6 +175,38 @@ object XGBoost extends Serializable {
|
|||||||
tracker
|
tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Refuse to train when Spark has been asked to encrypt traffic (`spark.ssl.enabled` is true),
 * unless the caller has explicitly opted out of this guard via `xgboost.spark.ignoreSsl`.
 *
 * Both settings are looked up on the active SparkSession when there is one; otherwise they
 * are read from the conf of the supplied SparkContext. Both are always read, in that order,
 * regardless of the value of the first.
 *
 * When SSL is enabled and the guard is overridden, a warning is logged and the method
 * returns normally. When SSL is enabled and the guard is not overridden, an Exception is
 * thrown. When SSL is not enabled, nothing happens.
 *
 * @param sc SparkContext of the training dataset; only consulted when no SparkSession
 *           is active.
 */
private def validateSparkSslConf(sc: SparkContext): Unit = {
  // Pick the conf source once, then read both boolean flags through the same lookup.
  val readFlag: String => Boolean = SparkSession.getActiveSession match {
    case Some(session) =>
      // A missing key counts as false; a present value is parsed with toBoolean.
      key => session.conf.getOption(key).exists(_.toBoolean)
    case None =>
      val contextConf = sc.getConf
      key => contextConf.getBoolean(key, false)
  }
  val sslRequested = readFlag("spark.ssl.enabled")
  val sslCheckWaived = readFlag("xgboost.spark.ignoreSsl")

  (sslRequested, sslCheckWaived) match {
    case (false, _) =>
      ()
    case (true, true) =>
      logger.warn("spark-xgboost is being run without encrypting data in transit! " +
        "Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
    case (true, false) =>
      throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
        "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
        "To override this protection and still use xgboost-spark at your own risk, " +
        "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
  }
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return A tuple of the booster and the metrics used to build training summary
|
* @return A tuple of the booster and the metrics used to build training summary
|
||||||
*/
|
*/
|
||||||
@ -187,6 +220,7 @@ object XGBoost extends Serializable {
|
|||||||
eval: EvalTrait = null,
|
eval: EvalTrait = null,
|
||||||
useExternalMemory: Boolean = false,
|
useExternalMemory: Boolean = false,
|
||||||
missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
|
missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
|
||||||
|
validateSparkSslConf(trainingData.context)
|
||||||
if (params.contains("tree_method")) {
|
if (params.contains("tree_method")) {
|
||||||
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
|
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
|
||||||
" for now")
|
" for now")
|
||||||
|
|||||||
@ -47,4 +47,34 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
|
|||||||
val eval = new EvalError()
|
val eval = new EvalError()
|
||||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("Check for Spark encryption over-the-wire") {
  // Remember the pre-existing setting so it can be put back afterwards.
  val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
  ss.conf.set("spark.ssl.enabled", true)

  // Everything that runs with the modified conf lives inside try/finally so that a failing
  // assertion or a failed fit cannot leave the SSL overrides set on the session.
  try {
    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
      "objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
    val training = buildDataFrame(Classification.train)

    withClue("xgboost-spark should throw an exception when spark.ssl.enabled = true but " +
      "xgboost.spark.ignoreSsl != true") {
      val thrown = intercept[Exception] {
        new XGBoostClassifier(paramMap).fit(training)
      }
      // The message must name both confs so users know how to override the check.
      assert(thrown.getMessage.contains("xgboost.spark.ignoreSsl") &&
        thrown.getMessage.contains("spark.ssl.enabled"))
    }

    // Confirm that this check can be overridden.
    ss.conf.set("xgboost.spark.ignoreSsl", true)
    new XGBoostClassifier(paramMap).fit(training)
  } finally {
    // Restore spark.ssl.enabled to whatever it was (or remove it if it was absent).
    originalSslConfOpt match {
      case None =>
        ss.conf.unset("spark.ssl.enabled")
      case Some(originalSslConf) =>
        ss.conf.set("spark.ssl.enabled", originalSslConf)
    }
    ss.conf.unset("xgboost.spark.ignoreSsl")
  }
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user