[jvm-packages] cancel job instead of killing SparkContext (#6019)
* cancel job instead of killing SparkContext This PR changes the default behavior that kills SparkContext. Instead, This PR cancels jobs when coming across task failed. That means the SparkContext is still alive even some exceptions happen. * add a parameter to control if killing SparkContext * cancel the jobs the failed task belongs to * remove the jobId from the map when one job failed. * resolve comments
This commit is contained in:
parent
3912f3de06
commit
0e2d5669f6
@ -76,7 +76,8 @@ private[this] case class XGBoostExecutionParams(
|
|||||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||||
cacheTrainingSet: Boolean,
|
cacheTrainingSet: Boolean,
|
||||||
treeMethod: Option[String],
|
treeMethod: Option[String],
|
||||||
isLocal: Boolean) {
|
isLocal: Boolean,
|
||||||
|
killSparkContextOnWorkerFailure: Boolean) {
|
||||||
|
|
||||||
private var rawParamMap: Map[String, Any] = _
|
private var rawParamMap: Map[String, Any] = _
|
||||||
|
|
||||||
@ -220,6 +221,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
|
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
|
||||||
.asInstanceOf[Boolean]
|
.asInstanceOf[Boolean]
|
||||||
|
|
||||||
|
val killSparkContext = overridedParams.getOrElse("kill_spark_context_on_worker_failure", true)
|
||||||
|
.asInstanceOf[Boolean]
|
||||||
|
|
||||||
val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
|
val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
|
||||||
missing, allowNonZeroForMissing, trackerConf,
|
missing, allowNonZeroForMissing, trackerConf,
|
||||||
timeoutRequestWorkers,
|
timeoutRequestWorkers,
|
||||||
@ -228,7 +232,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
xgbExecEarlyStoppingParams,
|
xgbExecEarlyStoppingParams,
|
||||||
cacheTrainingSet,
|
cacheTrainingSet,
|
||||||
treeMethod,
|
treeMethod,
|
||||||
isLocal)
|
isLocal,
|
||||||
|
killSparkContext)
|
||||||
xgbExecParam.setRawParamMap(overridedParams)
|
xgbExecParam.setRawParamMap(overridedParams)
|
||||||
xgbExecParam
|
xgbExecParam
|
||||||
}
|
}
|
||||||
@ -588,7 +593,8 @@ object XGBoost extends Serializable {
|
|||||||
val (booster, metrics) = try {
|
val (booster, metrics) = try {
|
||||||
val parallelismTracker = new SparkParallelismTracker(sc,
|
val parallelismTracker = new SparkParallelismTracker(sc,
|
||||||
xgbExecParams.timeoutRequestWorkers,
|
xgbExecParams.timeoutRequestWorkers,
|
||||||
xgbExecParams.numWorkers)
|
xgbExecParams.numWorkers,
|
||||||
|
xgbExecParams.killSparkContextOnWorkerFailure)
|
||||||
val rabitEnv = tracker.getWorkerEnvs
|
val rabitEnv = tracker.getWorkerEnvs
|
||||||
val boostersAndMetrics = if (hasGroup) {
|
val boostersAndMetrics = if (hasGroup) {
|
||||||
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
|
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
|
||||||
@ -628,7 +634,9 @@ object XGBoost extends Serializable {
|
|||||||
case t: Throwable =>
|
case t: Throwable =>
|
||||||
// if the job was aborted due to an exception
|
// if the job was aborted due to an exception
|
||||||
logger.error("the job was aborted due to ", t)
|
logger.error("the job was aborted due to ", t)
|
||||||
trainingData.sparkContext.stop()
|
if (xgbExecParams.killSparkContextOnWorkerFailure) {
|
||||||
|
trainingData.sparkContext.stop()
|
||||||
|
}
|
||||||
throw t
|
throw t
|
||||||
} finally {
|
} finally {
|
||||||
uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
|
uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
|
||||||
|
|||||||
@ -105,8 +105,14 @@ private[spark] trait LearningTaskParams extends Params {
|
|||||||
|
|
||||||
final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
|
final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
|
||||||
|
|
||||||
setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
|
/**
|
||||||
trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
|
* whether killing SparkContext when training task fails
|
||||||
|
*/
|
||||||
|
final val killSparkContextOnWorkerFailure = new BooleanParam(this,
|
||||||
|
"killSparkContextOnWorkerFailure", "whether killing SparkContext when training task fails")
|
||||||
|
|
||||||
|
setDefault(objective -> "reg:squarederror", baseScore -> 0.5, trainTestRatio -> 1.0,
|
||||||
|
numEarlyStoppingRounds -> 0, cacheTrainingSet -> false, killSparkContextOnWorkerFailure -> true)
|
||||||
}
|
}
|
||||||
|
|
||||||
private[spark] object LearningTaskParams {
|
private[spark] object LearningTaskParams {
|
||||||
|
|||||||
@ -19,6 +19,8 @@ package org.apache.spark
|
|||||||
import org.apache.commons.logging.LogFactory
|
import org.apache.commons.logging.LogFactory
|
||||||
import org.apache.spark.scheduler._
|
import org.apache.spark.scheduler._
|
||||||
|
|
||||||
|
import scala.collection.mutable.{HashMap, HashSet}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A tracker that ensures enough number of executor cores are alive.
|
* A tracker that ensures enough number of executor cores are alive.
|
||||||
* Throws an exception when the number of alive cores is less than nWorkers.
|
* Throws an exception when the number of alive cores is less than nWorkers.
|
||||||
@ -26,11 +28,13 @@ import org.apache.spark.scheduler._
|
|||||||
* @param sc The SparkContext object
|
* @param sc The SparkContext object
|
||||||
* @param timeout The maximum time to wait for enough number of workers.
|
* @param timeout The maximum time to wait for enough number of workers.
|
||||||
* @param numWorkers nWorkers used in an XGBoost Job
|
* @param numWorkers nWorkers used in an XGBoost Job
|
||||||
|
* @param killSparkContextOnWorkerFailure kill SparkContext or not when task fails
|
||||||
*/
|
*/
|
||||||
class SparkParallelismTracker(
|
class SparkParallelismTracker(
|
||||||
val sc: SparkContext,
|
val sc: SparkContext,
|
||||||
timeout: Long,
|
timeout: Long,
|
||||||
numWorkers: Int) {
|
numWorkers: Int,
|
||||||
|
killSparkContextOnWorkerFailure: Boolean = true) {
|
||||||
|
|
||||||
private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
|
private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
|
||||||
private[this] val logger = LogFactory.getLog("XGBoostSpark")
|
private[this] val logger = LogFactory.getLog("XGBoostSpark")
|
||||||
@ -58,7 +62,7 @@ class SparkParallelismTracker(
|
|||||||
}
|
}
|
||||||
|
|
||||||
private[this] def safeExecute[T](body: => T): T = {
|
private[this] def safeExecute[T](body: => T): T = {
|
||||||
val listener = new TaskFailedListener
|
val listener = new TaskFailedListener(killSparkContextOnWorkerFailure)
|
||||||
sc.addSparkListener(listener)
|
sc.addSparkListener(listener)
|
||||||
try {
|
try {
|
||||||
body
|
body
|
||||||
@ -79,7 +83,7 @@ class SparkParallelismTracker(
|
|||||||
def execute[T](body: => T): T = {
|
def execute[T](body: => T): T = {
|
||||||
if (timeout <= 0) {
|
if (timeout <= 0) {
|
||||||
logger.info("starting training without setting timeout for waiting for resources")
|
logger.info("starting training without setting timeout for waiting for resources")
|
||||||
body
|
safeExecute(body)
|
||||||
} else {
|
} else {
|
||||||
logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
|
logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
|
||||||
if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
|
if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
|
||||||
@ -90,16 +94,51 @@ class SparkParallelismTracker(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private[spark] class TaskFailedListener extends SparkListener {
|
class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener {
|
||||||
|
|
||||||
private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
|
private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
|
||||||
|
|
||||||
|
// {jobId, [stageId0, stageId1, ...] }
|
||||||
|
// keep track of the mapping of job id and stage ids
|
||||||
|
// when a task fails, find the job id and stage id the task belongs to, finally
|
||||||
|
// cancel the jobs
|
||||||
|
private val jobIdToStageIds: HashMap[Int, HashSet[Int]] = HashMap.empty
|
||||||
|
|
||||||
|
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
|
||||||
|
if (!killSparkContext) {
|
||||||
|
jobStart.stageIds.foreach(stageId => {
|
||||||
|
jobIdToStageIds.getOrElseUpdate(jobStart.jobId, new HashSet[Int]()) += stageId
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
|
||||||
|
if (!killSparkContext) {
|
||||||
|
jobIdToStageIds.remove(jobEnd.jobId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
|
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
|
||||||
taskEnd.reason match {
|
taskEnd.reason match {
|
||||||
case taskEndReason: TaskFailedReason =>
|
case taskEndReason: TaskFailedReason =>
|
||||||
logger.error(s"Training Task Failed during XGBoost Training: " +
|
logger.error(s"Training Task Failed during XGBoost Training: " +
|
||||||
s"$taskEndReason, stopping SparkContext")
|
s"$taskEndReason")
|
||||||
TaskFailedListener.startedSparkContextKiller()
|
if (killSparkContext) {
|
||||||
|
logger.error("killing SparkContext")
|
||||||
|
TaskFailedListener.startedSparkContextKiller()
|
||||||
|
} else {
|
||||||
|
val stageId = taskEnd.stageId
|
||||||
|
// find job ids according to stage id and then cancel the job
|
||||||
|
|
||||||
|
jobIdToStageIds.foreach {
|
||||||
|
case (jobId, stageIds) =>
|
||||||
|
if (stageIds.contains(stageId)) {
|
||||||
|
logger.error("Cancelling jobId:" + jobId)
|
||||||
|
jobIdToStageIds.remove(jobId)
|
||||||
|
SparkContext.getOrCreate().cancelJob(jobId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
case _ =>
|
case _ =>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -116,4 +116,28 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
|||||||
assert(waitAndCheckSparkShutdown(100) == true)
|
assert(waitAndCheckSparkShutdown(100) == true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("test SparkContext should not be killed ") {
|
||||||
|
val training = buildDataFrame(Classification.train)
|
||||||
|
// mock rank 0 failure during 8th allreduce synchronization
|
||||||
|
Rabit.mockList = Array("0,8,0,0").toList.asJava
|
||||||
|
|
||||||
|
try {
|
||||||
|
new XGBoostClassifier(Map(
|
||||||
|
"eta" -> "0.1",
|
||||||
|
"max_depth" -> "10",
|
||||||
|
"verbosity" -> "1",
|
||||||
|
"objective" -> "binary:logistic",
|
||||||
|
"num_round" -> 5,
|
||||||
|
"num_workers" -> numWorkers,
|
||||||
|
"kill_spark_context_on_worker_failure" -> false,
|
||||||
|
"rabit_timeout" -> 0))
|
||||||
|
.fit(training)
|
||||||
|
} catch {
|
||||||
|
case e: Throwable => // swallow anything
|
||||||
|
} finally {
|
||||||
|
// wait 3s to check if SparkContext is killed
|
||||||
|
assert(waitAndCheckSparkShutdown(3000) == false)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -34,6 +34,15 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
|
|||||||
.config("spark.driver.memory", "512m")
|
.config("spark.driver.memory", "512m")
|
||||||
.config("spark.task.cpus", 1)
|
.config("spark.task.cpus", 1)
|
||||||
|
|
||||||
|
private def waitAndCheckSparkShutdown(waitMiliSec: Int): Boolean = {
|
||||||
|
var totalWaitedTime = 0L
|
||||||
|
while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMiliSec) {
|
||||||
|
Thread.sleep(100)
|
||||||
|
totalWaitedTime += 100
|
||||||
|
}
|
||||||
|
ss.sparkContext.isStopped
|
||||||
|
}
|
||||||
|
|
||||||
test("tracker should not affect execution result when timeout is not larger than 0") {
|
test("tracker should not affect execution result when timeout is not larger than 0") {
|
||||||
val nWorkers = numParallelism
|
val nWorkers = numParallelism
|
||||||
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
|
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
|
||||||
@ -74,4 +83,69 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("tracker should not kill SparkContext when killSparkContextOnWorkerFailure=false") {
|
||||||
|
val nWorkers = numParallelism
|
||||||
|
val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
|
||||||
|
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers)
|
||||||
|
try {
|
||||||
|
tracker.execute {
|
||||||
|
rdd.map { i =>
|
||||||
|
val partitionId = TaskContext.get().partitionId()
|
||||||
|
if (partitionId == 0) {
|
||||||
|
throw new RuntimeException("mocking task failing")
|
||||||
|
}
|
||||||
|
i
|
||||||
|
}.sum()
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
case e: Exception => // catch the exception
|
||||||
|
} finally {
|
||||||
|
// wait 3s to check if SparkContext is killed
|
||||||
|
assert(waitAndCheckSparkShutdown(3000) == false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("tracker should cancel the correct job when killSparkContextOnWorkerFailure=false") {
|
||||||
|
val nWorkers = 2
|
||||||
|
val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
|
||||||
|
val rdd: RDD[Int] = sc.parallelize(1 to 10, nWorkers)
|
||||||
|
val thread = new TestThread(sc)
|
||||||
|
thread.start()
|
||||||
|
try {
|
||||||
|
tracker.execute {
|
||||||
|
rdd.map { i =>
|
||||||
|
Thread.sleep(100)
|
||||||
|
val partitionId = TaskContext.get().partitionId()
|
||||||
|
if (partitionId == 0) {
|
||||||
|
throw new RuntimeException("mocking task failing")
|
||||||
|
}
|
||||||
|
i
|
||||||
|
}.sum()
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
case e: Exception => // catch the exception
|
||||||
|
} finally {
|
||||||
|
thread.join(8000)
|
||||||
|
// wait 3s to check if SparkContext is killed
|
||||||
|
assert(waitAndCheckSparkShutdown(3000) == false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private[this] class TestThread(sc: SparkContext) extends Thread {
|
||||||
|
override def run(): Unit = {
|
||||||
|
var sum: Double = 0.0f
|
||||||
|
try {
|
||||||
|
val rdd = sc.parallelize(1 to 4, 2)
|
||||||
|
sum = rdd.mapPartitions(iter => {
|
||||||
|
// sleep 2s to ensure task is alive when cancelling other jobs
|
||||||
|
Thread.sleep(2000)
|
||||||
|
iter
|
||||||
|
}).sum()
|
||||||
|
} finally {
|
||||||
|
// get the correct result
|
||||||
|
assert(sum.toInt == 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user