diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index ccc37ebea..6302c35e4 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -31,6 +31,7 @@ import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
+import org.apache.spark.sql.SparkSession
 
 
 /**
@@ -174,6 +175,38 @@ object XGBoost extends Serializable {
     tracker
   }
 
+  /**
+   * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
+   * If so, throw an exception unless this safety measure has been explicitly overridden
+   * via conf `xgboost.spark.ignoreSsl`.
+   *
+   * @param sc SparkContext for the training dataset. When looking for the confs, this method
+   *           first checks for an active SparkSession. If one is not available, it falls back
+   *           to this SparkContext.
+   * @throws Exception if `spark.ssl.enabled` is true and `xgboost.spark.ignoreSsl` is not.
+   */
+  private def validateSparkSslConf(sc: SparkContext): Unit = {
+    // Pick the conf source once so both flags are always read from the same place.
+    val readFlag: String => Boolean = SparkSession.getActiveSession match {
+      case Some(ss) => (key: String) => ss.conf.getOption(key).exists(_.toBoolean)
+      case None => (key: String) => sc.getConf.getBoolean(key, false)
+    }
+    val sparkSslEnabled = readFlag("spark.ssl.enabled")
+    val xgboostSparkIgnoreSsl = readFlag("xgboost.spark.ignoreSsl")
+
+    if (sparkSslEnabled) {
+      if (xgboostSparkIgnoreSsl) {
+        logger.warn("spark-xgboost is being run without encrypting data in transit! " +
+          "Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
+      } else {
+        throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
+          "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
+          "To override this protection and still use xgboost-spark at your own risk, " +
+          "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
+      }
+    }
+  }
+
   /**
    * @return A tuple of the booster and the metrics used to build training summary
    */
@@ -187,6 +220,7 @@ object XGBoost extends Serializable {
       eval: EvalTrait = null,
       useExternalMemory: Boolean = false,
       missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
+    validateSparkSslConf(trainingData.context)
     if (params.contains("tree_method")) {
       require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
         " for now")
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
index 3b9eae707..fe16bcda5 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
@@ -47,4 +47,38 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
     val eval = new EvalError()
     assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
   }
+
+  test("Check for Spark encryption over-the-wire") {
+    val sslKey = "spark.ssl.enabled"
+    val ignoreSslKey = "xgboost.spark.ignoreSsl"
+    // Capture the pre-test values of both confs so they can be put back afterwards,
+    // including when an assertion or a fit below fails.
+    val originalConfs = Seq(sslKey, ignoreSslKey).map(key => key -> ss.conf.getOption(key))
+
+    try {
+      ss.conf.set(sslKey, true)
+
+      val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+        "objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
+      val training = buildDataFrame(Classification.train)
+
+      withClue("xgboost-spark should throw an exception when spark.ssl.enabled = true but " +
+        "xgboost.spark.ignoreSsl != true") {
+        val thrown = intercept[Exception] {
+          new XGBoostClassifier(paramMap).fit(training)
+        }
+        assert(thrown.getMessage.contains("xgboost.spark.ignoreSsl") &&
+          thrown.getMessage.contains("spark.ssl.enabled"))
+      }
+
+      // Confirm that this check can be overridden.
+      ss.conf.set(ignoreSslKey, true)
+      new XGBoostClassifier(paramMap).fit(training)
+    } finally {
+      originalConfs.foreach {
+        case (key, Some(value)) => ss.conf.set(key, value)
+        case (key, None) => ss.conf.unset(key)
+      }
+    }
+  }
 }