diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
index 403ea73f3..4b62569dd 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -20,7 +20,7 @@ import java.net.URL
 
 import org.apache.commons.logging.LogFactory
-import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorRemoved, SparkListenerTaskEnd}
+import org.apache.spark.scheduler._
 import org.codehaus.jackson.map.ObjectMapper
 
 import scala.collection.JavaConverters._
 import scala.concurrent.ExecutionContext.Implicits.global
@@ -114,21 +114,26 @@ class SparkParallelismTracker(
   }
 }
 
-private class ErrorInXGBoostTraining(msg: String) extends ControlThrowable {
-  override def toString: String = s"ErrorInXGBoostTraining: $msg"
-}
-
 private[spark] class TaskFailedListener extends SparkListener {
+
+  private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
+
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
     taskEnd.reason match {
-      case taskEnd: SparkListenerTaskEnd =>
-        if (taskEnd.reason.isInstanceOf[TaskFailedReason]) {
-          throw new ErrorInXGBoostTraining(s"TaskFailed during XGBoost Training: " +
-            s"${taskEnd.reason}")
+      case taskEndReason: TaskFailedReason =>
+        logger.error(s"Training Task Failed during XGBoost Training: " +
+          s"$taskEndReason, stopping SparkContext")
+        // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
+        // in a separate thread
+        val sparkContextKiller = new Thread() {
+          override def run(): Unit = {
+            LiveListenerBus.withinListenerThread.withValue(false) {
+              SparkContext.getOrCreate().stop()
+            }
+          }
         }
-      case executorRemoved: SparkListenerExecutorRemoved =>
-        throw new ErrorInXGBoostTraining(s"Executor lost during XGBoost Training: " +
-          s"${executorRemoved.reason}")
+        sparkContextKiller.setDaemon(true)
+        sparkContextKiller.start()
       case _ =>
     }
   }
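
To see the new failure path end to end: once the listener is registered, any task that finishes with a `TaskFailedReason` is logged and the `SparkContext` is shut down from a separate daemon thread. The sketch below is a hypothetical driver program, not part of this patch; the demo object name and the deliberately failing job are assumptions for illustration, and the file must be compiled into the `org.apache.spark` package because `TaskFailedListener` is declared `private[spark]`.

```scala
// Hypothetical demo, not part of the patch. It must live in the
// org.apache.spark package because TaskFailedListener is private[spark].
package org.apache.spark

object TaskFailedListenerDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("task-failed-listener-demo"))

    // Register the listener the way a training entry point would.
    sc.addSparkListener(new TaskFailedListener)

    try {
      // Provoke a task failure so onTaskEnd receives a TaskFailedReason.
      sc.parallelize(1 to 10, numSlices = 2).foreach { i =>
        if (i == 5) throw new RuntimeException("simulated training failure")
      }
    } catch {
      case e: SparkException =>
        // The job fails here while the listener's daemon thread stops the
        // context in the background, so isStopped flips shortly afterwards.
        println(s"Job failed: ${e.getMessage}; context stopped = ${sc.isStopped}")
    }
  }
}
```

The detour through `sparkContextKiller` is needed because `SparkContext.stop()` refuses to run on the listener bus thread, and `LiveListenerBus.withinListenerThread` is a `DynamicVariable` backed by an `InheritableThreadLocal`, so a thread spawned from the listener thread would inherit `true`; resetting it to `false` in the new thread is what lets `stop()` proceed.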