[jvm-packages] fix safe execution (#4046)
This commit is contained in:
parent
6a569b8cd9
commit
e290ec9a80
@@ -20,7 +20,7 @@ import java.net.URL
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
|
||||
import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorRemoved, SparkListenerTaskEnd}
|
||||
import org.codehaus.jackson.map.ObjectMapper
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
@@ -98,9 +98,11 @@ class SparkParallelismTracker(
|
||||
*/
|
||||
def execute[T](body: => T): T = {
|
||||
if (timeout <= 0) {
|
||||
logger.info("starting training without setting timeout for waiting for resources")
|
||||
body
|
||||
} else {
|
||||
try {
|
||||
logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
|
||||
waitForCondition(numAliveCores >= requestedCores, timeout)
|
||||
} catch {
|
||||
case _: TimeoutException =>
|
||||
@@ -119,9 +121,14 @@ private class ErrorInXGBoostTraining(msg: String) extends ControlThrowable {
|
||||
private[spark] class TaskFailedListener extends SparkListener {
|
||||
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
|
||||
taskEnd.reason match {
|
||||
case reason: TaskFailedReason =>
|
||||
throw new ErrorInXGBoostTraining(s"ExecutorLost during XGBoost Training: " +
|
||||
s"${reason.toErrorString}")
|
||||
case taskEnd: SparkListenerTaskEnd =>
|
||||
if (taskEnd.reason.isInstanceOf[TaskFailedReason]) {
|
||||
throw new ErrorInXGBoostTraining(s"TaskFailed during XGBoost Training: " +
|
||||
s"${taskEnd.reason}")
|
||||
}
|
||||
case executorRemoved: SparkListenerExecutorRemoved =>
|
||||
throw new ErrorInXGBoostTraining(s"Executor lost during XGBoost Training: " +
|
||||
s"${executorRemoved.reason}")
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user