diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala index 3e514ebd8..99c1cccf2 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala @@ -146,22 +146,30 @@ class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener object TaskFailedListener { - var killerStarted = false + var killerStarted: Boolean = false + + var sparkContextKiller: Thread = _ + + val sparkContextShutdownLock = new AnyRef private def startedSparkContextKiller(): Unit = this.synchronized { if (!killerStarted) { + killerStarted = true // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it // in a separate thread - val sparkContextKiller = new Thread() { + sparkContextKiller = new Thread() { override def run(): Unit = { LiveListenerBus.withinListenerThread.withValue(false) { - SparkContext.getOrCreate().stop() + sparkContextShutdownLock.synchronized { + SparkContext.getActive.foreach(_.stop()) + killerStarted = false + sparkContextShutdownLock.notify() + } } } } sparkContextKiller.setDaemon(true) sparkContextKiller.start() - killerStarted = true } } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala index 341db97bc..6148e6dbe 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala @@ -45,12 +45,26 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite => override def beforeEach(): Unit = getOrCreateSession override def afterEach() { - synchronized { + 
TaskFailedListener.sparkContextShutdownLock.synchronized { if (currentSession != null) { + // this synchronization is mostly for the tests involving SparkContext shutdown + // for unit tests involving SparkContext shutdown there are two possible event sequences: + // 1. SparkContext killer is executed before afterEach, in this case, before SparkContext + // is fully stopped, afterEach() will block at the following code block + // 2. SparkContext killer is executed after afterEach; in this case currentSession.stop() + // blocks until every ListenerBus message is processed. Because currentSession.stop() + // has been called, SparkContext killer will not take effect + while (TaskFailedListener.killerStarted) { + TaskFailedListener.sparkContextShutdownLock.wait() + } currentSession.stop() cleanExternalCache(currentSession.sparkContext.appName) currentSession = null } + if (TaskFailedListener.sparkContextKiller != null) { + TaskFailedListener.sparkContextKiller.interrupt() + TaskFailedListener.sparkContextKiller = null + } TaskFailedListener.killerStarted = false } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala index e1d58f26e..2e51f15b0 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala @@ -114,6 +114,7 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest { // assume all tasks throw exception almost same time // 100ms should be enough to exhaust all retries assert(waitAndCheckSparkShutdown(100) == true) + TaskFailedListener.killerStarted = false } }