[jvm-packages] cancel job instead of killing SparkContext (#6019)

* cancel job instead of killing SparkContext

This PR changes the default behavior that kills SparkContext. Instead, This PR
cancels jobs when coming across task failed. That means the SparkContext is
still alive even some exceptions happen.

* add a parameter to control if killing SparkContext

* cancel the jobs the failed task belongs to

* remove the jobId from the map when one job failed.

* resolve comments
This commit is contained in:
Bobby Wang
2020-09-03 05:20:59 +08:00
committed by GitHub
parent 3912f3de06
commit 0e2d5669f6
5 changed files with 163 additions and 12 deletions

View File

@@ -76,7 +76,8 @@ private[this] case class XGBoostExecutionParams(
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
cacheTrainingSet: Boolean,
treeMethod: Option[String],
isLocal: Boolean) {
isLocal: Boolean,
killSparkContextOnWorkerFailure: Boolean) {
private var rawParamMap: Map[String, Any] = _
@@ -220,6 +221,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
.asInstanceOf[Boolean]
val killSparkContext = overridedParams.getOrElse("kill_spark_context_on_worker_failure", true)
.asInstanceOf[Boolean]
val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
missing, allowNonZeroForMissing, trackerConf,
timeoutRequestWorkers,
@@ -228,7 +232,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
xgbExecEarlyStoppingParams,
cacheTrainingSet,
treeMethod,
isLocal)
isLocal,
killSparkContext)
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}
@@ -588,7 +593,8 @@ object XGBoost extends Serializable {
val (booster, metrics) = try {
val parallelismTracker = new SparkParallelismTracker(sc,
xgbExecParams.timeoutRequestWorkers,
xgbExecParams.numWorkers)
xgbExecParams.numWorkers,
xgbExecParams.killSparkContextOnWorkerFailure)
val rabitEnv = tracker.getWorkerEnvs
val boostersAndMetrics = if (hasGroup) {
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
@@ -628,7 +634,9 @@ object XGBoost extends Serializable {
case t: Throwable =>
// if the job was aborted due to an exception
logger.error("the job was aborted due to ", t)
trainingData.sparkContext.stop()
if (xgbExecParams.killSparkContextOnWorkerFailure) {
trainingData.sparkContext.stop()
}
throw t
} finally {
uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)

View File

@@ -105,8 +105,14 @@ private[spark] trait LearningTaskParams extends Params {
final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
/**
* whether killing SparkContext when training task fails
*/
final val killSparkContextOnWorkerFailure = new BooleanParam(this,
"killSparkContextOnWorkerFailure", "whether killing SparkContext when training task fails")
setDefault(objective -> "reg:squarederror", baseScore -> 0.5, trainTestRatio -> 1.0,
numEarlyStoppingRounds -> 0, cacheTrainingSet -> false, killSparkContextOnWorkerFailure -> true)
}
private[spark] object LearningTaskParams {

View File

@@ -19,6 +19,8 @@ package org.apache.spark
import org.apache.commons.logging.LogFactory
import org.apache.spark.scheduler._
import scala.collection.mutable.{HashMap, HashSet}
/**
* A tracker that ensures enough number of executor cores are alive.
* Throws an exception when the number of alive cores is less than nWorkers.
@@ -26,11 +28,13 @@ import org.apache.spark.scheduler._
* @param sc The SparkContext object
* @param timeout The maximum time to wait for enough number of workers.
* @param numWorkers nWorkers used in an XGBoost Job
* @param killSparkContextOnWorkerFailure kill SparkContext or not when task fails
*/
class SparkParallelismTracker(
val sc: SparkContext,
timeout: Long,
numWorkers: Int) {
numWorkers: Int,
killSparkContextOnWorkerFailure: Boolean = true) {
private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
private[this] val logger = LogFactory.getLog("XGBoostSpark")
@@ -58,7 +62,7 @@ class SparkParallelismTracker(
}
private[this] def safeExecute[T](body: => T): T = {
val listener = new TaskFailedListener
val listener = new TaskFailedListener(killSparkContextOnWorkerFailure)
sc.addSparkListener(listener)
try {
body
@@ -79,7 +83,7 @@ class SparkParallelismTracker(
def execute[T](body: => T): T = {
if (timeout <= 0) {
logger.info("starting training without setting timeout for waiting for resources")
body
safeExecute(body)
} else {
logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
@@ -90,16 +94,51 @@ class SparkParallelismTracker(
}
}
private[spark] class TaskFailedListener extends SparkListener {
class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener {
private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
// {jobId, [stageId0, stageId1, ...] }
// keep track of the mapping of job id and stage ids
// when a task fails, find the job id and stage id the task belongs to, finally
// cancel the jobs
private val jobIdToStageIds: HashMap[Int, HashSet[Int]] = HashMap.empty
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
if (!killSparkContext) {
jobStart.stageIds.foreach(stageId => {
jobIdToStageIds.getOrElseUpdate(jobStart.jobId, new HashSet[Int]()) += stageId
})
}
}
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
if (!killSparkContext) {
jobIdToStageIds.remove(jobEnd.jobId)
}
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
taskEnd.reason match {
case taskEndReason: TaskFailedReason =>
logger.error(s"Training Task Failed during XGBoost Training: " +
s"$taskEndReason, stopping SparkContext")
TaskFailedListener.startedSparkContextKiller()
s"$taskEndReason")
if (killSparkContext) {
logger.error("killing SparkContext")
TaskFailedListener.startedSparkContextKiller()
} else {
val stageId = taskEnd.stageId
// find job ids according to stage id and then cancel the job
jobIdToStageIds.foreach {
case (jobId, stageIds) =>
if (stageIds.contains(stageId)) {
logger.error("Cancelling jobId:" + jobId)
jobIdToStageIds.remove(jobId)
SparkContext.getOrCreate().cancelJob(jobId)
}
}
}
case _ =>
}
}