[jvm-packages] fix safe execution (#4046)

This commit is contained in:
Nan Zhu 2019-01-05 19:45:37 -08:00 committed by GitHub
parent 6a569b8cd9
commit e290ec9a80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,7 +20,7 @@ import java.net.URL
import org.apache.commons.logging.LogFactory
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorRemoved, SparkListenerTaskEnd}
import org.codehaus.jackson.map.ObjectMapper
import scala.collection.JavaConverters._
import scala.concurrent.ExecutionContext.Implicits.global
@ -98,9 +98,11 @@ class SparkParallelismTracker(
*/
def execute[T](body: => T): T = {
if (timeout <= 0) {
logger.info("starting training without setting timeout for waiting for resources")
body
} else {
try {
logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
waitForCondition(numAliveCores >= requestedCores, timeout)
} catch {
case _: TimeoutException =>
@ -119,9 +121,14 @@ private class ErrorInXGBoostTraining(msg: String) extends ControlThrowable {
private[spark] class TaskFailedListener extends SparkListener {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
taskEnd.reason match {
case reason: TaskFailedReason =>
throw new ErrorInXGBoostTraining(s"ExecutorLost during XGBoost Training: " +
s"${reason.toErrorString}")
case taskEnd: SparkListenerTaskEnd =>
if (taskEnd.reason.isInstanceOf[TaskFailedReason]) {
throw new ErrorInXGBoostTraining(s"TaskFailed during XGBoost Training: " +
s"${taskEnd.reason}")
}
case executorRemoved: SparkListenerExecutorRemoved =>
throw new ErrorInXGBoostTraining(s"Executor lost during XGBoost Training: " +
s"${executorRemoved.reason}")
case _ =>
}
}