diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala index 7a950eed7..10c6167ae 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala @@ -19,13 +19,14 @@ package org.apache.spark import java.net.URL import org.apache.commons.logging.LogFactory + import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.codehaus.jackson.map.ObjectMapper - import scala.collection.JavaConverters._ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.{Await, Future, TimeoutException} +import scala.util.control.ControlThrowable /** * A tracker that ensures enough number of executor cores are alive. @@ -111,11 +112,15 @@ class SparkParallelismTracker( } } +private class ErrorInXGBoostTraining(msg: String) extends ControlThrowable { + override def toString: String = s"ErrorInXGBoostTraining: $msg" +} + private[spark] class TaskFailedListener extends SparkListener { override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { taskEnd.reason match { case reason: TaskFailedReason => - throw new InterruptedException(s"ExecutorLost during XGBoost Training: " + + throw new ErrorInXGBoostTraining(s"ExecutorLost during XGBoost Training: " + s"${reason.toErrorString}") case _ => }