[jvm-packages] cancel job instead of killing SparkContext (#6019)

* cancel job instead of killing SparkContext

This PR changes the default behavior that kills SparkContext. Instead, This PR
cancels jobs when coming across task failed. That means the SparkContext is
still alive even some exceptions happen.

* add a parameter to control if killing SparkContext

* cancel the jobs the failed task belongs to

* remove the jobId from the map when one job failed.

* resolve comments
This commit is contained in:
Bobby Wang
2020-09-03 05:20:59 +08:00
committed by GitHub
parent 3912f3de06
commit 0e2d5669f6
5 changed files with 163 additions and 12 deletions

View File

@@ -116,4 +116,28 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
assert(waitAndCheckSparkShutdown(100) == true)
}
}
test("test SparkContext should not be killed ") {
val training = buildDataFrame(Classification.train)
// mock rank 0 failure during 8th allreduce synchronization
Rabit.mockList = Array("0,8,0,0").toList.asJava
try {
new XGBoostClassifier(Map(
"eta" -> "0.1",
"max_depth" -> "10",
"verbosity" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5,
"num_workers" -> numWorkers,
"kill_spark_context_on_worker_failure" -> false,
"rabit_timeout" -> 0))
.fit(training)
} catch {
case e: Throwable => // swallow anything
} finally {
// wait 3s to check if SparkContext is killed
assert(waitAndCheckSparkShutdown(3000) == false)
}
}
}

View File

@@ -34,6 +34,15 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
.config("spark.driver.memory", "512m")
.config("spark.task.cpus", 1)
private def waitAndCheckSparkShutdown(waitMiliSec: Int): Boolean = {
var totalWaitedTime = 0L
while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMiliSec) {
Thread.sleep(100)
totalWaitedTime += 100
}
ss.sparkContext.isStopped
}
test("tracker should not affect execution result when timeout is not larger than 0") {
val nWorkers = numParallelism
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
@@ -74,4 +83,69 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
}
}
}
test("tracker should not kill SparkContext when killSparkContextOnWorkerFailure=false") {
val nWorkers = numParallelism
val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers)
try {
tracker.execute {
rdd.map { i =>
val partitionId = TaskContext.get().partitionId()
if (partitionId == 0) {
throw new RuntimeException("mocking task failing")
}
i
}.sum()
}
} catch {
case e: Exception => // catch the exception
} finally {
// wait 3s to check if SparkContext is killed
assert(waitAndCheckSparkShutdown(3000) == false)
}
}
test("tracker should cancel the correct job when killSparkContextOnWorkerFailure=false") {
val nWorkers = 2
val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
val rdd: RDD[Int] = sc.parallelize(1 to 10, nWorkers)
val thread = new TestThread(sc)
thread.start()
try {
tracker.execute {
rdd.map { i =>
Thread.sleep(100)
val partitionId = TaskContext.get().partitionId()
if (partitionId == 0) {
throw new RuntimeException("mocking task failing")
}
i
}.sum()
}
} catch {
case e: Exception => // catch the exception
} finally {
thread.join(8000)
// wait 3s to check if SparkContext is killed
assert(waitAndCheckSparkShutdown(3000) == false)
}
}
private[this] class TestThread(sc: SparkContext) extends Thread {
override def run(): Unit = {
var sum: Double = 0.0f
try {
val rdd = sc.parallelize(1 to 4, 2)
sum = rdd.mapPartitions(iter => {
// sleep 2s to ensure task is alive when cancelling other jobs
Thread.sleep(2000)
iter
}).sum()
} finally {
// get the correct result
assert(sum.toInt == 10)
}
}
}
}