From 05243642bb64a2865d35cc853058e7986cf5686e Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Thu, 7 Feb 2019 09:02:17 -0800 Subject: [PATCH] [jvm-packages] better fix for shutdown applications (#4108) * intentionally failed task * throw exception * more * stop sparkcontext directly * stop from another thread * new scope * use a new thread * daemon threads * don't join the killer thread * remove injected errors * add comments --- .../spark/SparkParallelismTracker.scala | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala index 403ea73f3..4b62569dd 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala @@ -20,7 +20,7 @@ import java.net.URL import org.apache.commons.logging.LogFactory -import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorRemoved, SparkListenerTaskEnd} +import org.apache.spark.scheduler._ import org.codehaus.jackson.map.ObjectMapper import scala.collection.JavaConverters._ import scala.concurrent.ExecutionContext.Implicits.global @@ -114,21 +114,26 @@ class SparkParallelismTracker( } } -private class ErrorInXGBoostTraining(msg: String) extends ControlThrowable { - override def toString: String = s"ErrorInXGBoostTraining: $msg" -} - private[spark] class TaskFailedListener extends SparkListener { + + private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener") + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { taskEnd.reason match { - case taskEnd: SparkListenerTaskEnd => - if (taskEnd.reason.isInstanceOf[TaskFailedReason]) { - throw new ErrorInXGBoostTraining(s"TaskFailed during XGBoost Training: " + - s"${taskEnd.reason}") + case taskEndReason: TaskFailedReason => + logger.error(s"Training Task Failed during XGBoost Training: " + + s"$taskEndReason, stopping SparkContext") + // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it + // in a separate thread + val sparkContextKiller = new Thread() { + override def run(): Unit = { + LiveListenerBus.withinListenerThread.withValue(false) { + SparkContext.getOrCreate().stop() + } + } } - case executorRemoved: SparkListenerExecutorRemoved => - throw new ErrorInXGBoostTraining(s"Executor lost during XGBoost Training: " + - s"${executorRemoved.reason}") + sparkContextKiller.setDaemon(true) + sparkContextKiller.start() case _ => } }