[jvm-packages] Add SparkParallelismTracker to prevent job from hanging (#2697)
* Add SparkParallelismTracker to prevent job from hanging
* Code review comments
* Code review comments
* Fix unit tests
* Changes and unit test to catch the corner case
* Update documentation
* Small improvements
* cancelAllJobs is problematic with scalatest; remove it
* Code review comments
* Check the number of executor cores beforehand, and throw an exception if any core is lost
* Address CR comments
* Add missing class
* Fix flaky unit test
* Address CR comments
* Remove redundant param for TaskFailedListener
parent 78c4188cec
commit b678e1711d
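For orientation before the diff: a minimal usage sketch of the behaviour this commit adds, based on the unit test introduced below. `trainingRDD` and `numWorkers` are placeholder names, and the parameter values are illustrative; the one new knob is "timeout_request_workers".

    // Sketch only: placeholder data/names, mirroring the new unit test below.
    // "timeout_request_workers" caps how long training waits for enough executor
    // cores; any value <= 0 disables the check (the new default is 30 minutes).
    val paramMap = Map(
      "eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic",
      "timeout_request_workers" -> 30 * 60 * 1000L) // or 0L to switch the check off
    val model = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
      nWorkers = numWorkers)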
@@ -23,13 +23,12 @@ import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FSDataInputStream, Path}
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}

object TrackerConf {
  def apply(): TrackerConf = TrackerConf(0L, "python")
@@ -315,8 +314,17 @@ object XGBoost extends Serializable {
      case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
        "instance of TrackerConf.")
    }
    val timeoutRequestWorkers: Long = params.get("timeout_request_workers") match {
      case None => 0L
      case Some(interval: Long) => interval
      case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
        " an instance of Long.")
    }

    val tracker = startTracker(nWorkers, trackerConf)
    try {
      val sc = trainingData.sparkContext
      val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
      val overriddenParams = overrideParamsAccordingToTaskCPUs(params, trainingData.sparkContext)
      val boosters = buildDistributedBoosters(trainingData, overriddenParams,
        tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
@@ -329,7 +337,7 @@ object XGBoost extends Serializable {
      sparkJobThread.setUncaughtExceptionHandler(tracker)
      sparkJobThread.start()
      val isClsTask = isClassificationTask(params)
      val trackerReturnVal = tracker.waitFor(0L)
      val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
      logger.info(s"Rabit returns with exit code $trackerReturnVal")
      val model = postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams,
        sparkJobThread, isClsTask)
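The behavioural core of the change is the wrapped call above: the blocking tracker.waitFor(0L) now runs inside SparkParallelismTracker.execute, which waits up to timeoutRequestWorkers milliseconds for nWorkers alive cores before starting and fails the job if an executor is lost while it runs, instead of letting the Rabit wait hang forever. A sketch of the same guard pattern with assumed placeholder values:

    // Sketch only (placeholder timeout/worker count); `sc` and `tracker` are the
    // SparkContext and Rabit tracker from the surrounding code above.
    val guard = new SparkParallelismTracker(sc, timeout = 30 * 60 * 1000L, nWorkers = 8)
    val exitCode: Int = guard.execute {
      tracker.waitFor(0L) // any blocking body works here
    }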
@@ -70,6 +70,13 @@ trait GeneralParams extends Params {
   */
  val missing = new FloatParam(this, "missing", "the value treated as missing")

  /**
   * the interval to check whether total numCores is no smaller than nWorkers. default: 30 minutes
   */
  val timeoutRequestWorkers = new LongParam(this, "timeout_request_workers", "the maximum time to" +
    " request new Workers if numCores are insufficient. The timeout will be disabled if this" +
    " value is set smaller than or equal to 0.")

  /**
   * Rabit tracker configurations. The parameter must be provided as an instance of the
   * TrackerConf class, which has the following definition:
@@ -105,6 +112,6 @@ trait GeneralParams extends Params {
  setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
    useExternalMemory -> false, silent -> 0,
    customObj -> null, customEval -> null, missing -> Float.NaN,
    trackerConf -> TrackerConf(), seed -> 0
    trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L
  )
}
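The new default above works out to 30 minutes (30 * 60 * 1000 ms). A small illustration of values a caller might plausibly pass for this parameter; only the default itself comes from this commit, the other values are assumptions:

    // Illustrative values for "timeout_request_workers" (milliseconds).
    val defaultTimeout: Long = 30 * 60 * 1000L // 30 minutes, the new default above
    val fiveMinutes: Long = 5 * 60 * 1000L     // an assumed stricter setting
    val disabled: Long = 0L                    // <= 0 disables the pre-flight check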
@@ -0,0 +1,116 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package org.apache.spark

import java.net.URL

import org.apache.commons.logging.LogFactory
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.codehaus.jackson.map.ObjectMapper

import scala.collection.JavaConverters._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent.{Await, Future, TimeoutException}

/**
 * A tracker that ensures enough number of executor cores are alive.
 * Throws an exception when the number of alive cores is less than nWorkers.
 *
 * @param sc The SparkContext object
 * @param timeout The maximum time to wait for enough number of workers.
 * @param nWorkers nWorkers used in an XGBoost Job
 */
class SparkParallelismTracker(
    val sc: SparkContext,
    timeout: Long,
    nWorkers: Int) {

  private[this] val mapper = new ObjectMapper()
  private[this] val logger = LogFactory.getLog("XGBoostSpark")
  private[this] val url = sc.uiWebUrl match {
    case Some(baseUrl) => new URL(s"$baseUrl/api/v1/applications/${sc.applicationId}/executors")
    case _ => null
  }

  private[this] def numAliveCores: Int = {
    try {
      mapper.readTree(url).findValues("totalCores").asScala.map(_.asInt).sum
    } catch {
      case ex: Throwable =>
        logger.warn(s"Unable to read total number of alive cores from REST API. " +
          s"Health Check will be ignored.")
        ex.printStackTrace()
        Int.MaxValue
    }
  }

  private[this] def waitForCondition(
      condition: => Boolean,
      timeout: Long,
      checkInterval: Long = 100L) = {
    val monitor = Future {
      while (!condition) {
        Thread.sleep(checkInterval)
      }
    }
    Await.ready(monitor, timeout.millis)
  }

  private[this] def safeExecute[T](body: => T): T = {
    sc.listenerBus.listeners.add(0, new TaskFailedListener)
    try {
      body
    } finally {
      sc.listenerBus.listeners.remove(0)
    }
  }

  /**
   * Execute a blocking function call with two checks on enough nWorkers:
   * - Before the function starts, wait until there are enough executor cores.
   * - During the execution, throws an exception if there is any executor lost.
   *
   * @param body A blocking function call
   * @tparam T Return type
   * @return The return of body
   */
  def execute[T](body: => T): T = {
    if (timeout <= 0) {
      body
    } else {
      try {
        waitForCondition(numAliveCores >= nWorkers, timeout)
      } catch {
        case _: TimeoutException =>
          throw new IllegalStateException(s"Unable to get $nWorkers workers for XGBoost training")
      }
      safeExecute(body)
    }
  }
}

private[spark] class TaskFailedListener extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    taskEnd.reason match {
      case reason: TaskFailedReason =>
        throw new InterruptedException(s"ExecutorLost during XGBoost Training: " +
          s"${reason.toErrorString}")
      case _ =>
    }
  }
}
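To make the execute contract above concrete: a minimal sketch of using the new class directly, closely following the test suite added further down. The 10-second timeout, the helper name guardedSum, and the rdd.sum() body are arbitrary choices, not requirements of the API.

    import org.apache.spark.{SparkContext, SparkParallelismTracker}

    // Wait up to 10 seconds for `nWorkers` alive cores, then run the body.
    // If the cores never show up, execute throws IllegalStateException; if an
    // executor is lost while the body runs, the job fails instead of hanging.
    def guardedSum(sc: SparkContext, nWorkers: Int): Double = {
      val rdd = sc.parallelize(1 to nWorkers)
      val tracker = new SparkParallelismTracker(sc, timeout = 10000L, nWorkers = nWorkers)
      tracker.execute(rdd.sum())
    }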
@@ -376,4 +376,17 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
    val predResult1: Array[Array[Float]] = predRDD.collect()
    assert(testRDD.count() === predResult1.length)
  }

  test("training with spark parallelism checks disabled") {
    import DataUtils._
    val eval = new EvalError()
    val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
    val testSetDMatrix = new DMatrix(Classification.test.iterator)
    val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "timeout_request_workers" -> 0L).toMap
    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
      nWorkers = numWorkers)
    assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
      testSetDMatrix) < 0.1)
  }
}
@@ -0,0 +1,57 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package org.apache.spark

import org.apache.spark.rdd.RDD
import org.scalatest.{BeforeAndAfterAll, FunSuite}

class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
  var sc: SparkContext = _
  var numParallelism: Int = _

  override def beforeAll(): Unit = {
    val conf: SparkConf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("XGBoostSuite")
    sc = new SparkContext(conf)
    numParallelism = sc.defaultParallelism
  }

  test("tracker should not affect execution result") {
    val nWorkers = numParallelism
    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
    val tracker = new SparkParallelismTracker(sc, 10000, nWorkers)
    val disabledTracker = new SparkParallelismTracker(sc, 0, nWorkers)
    assert(tracker.execute(rdd.sum()) == rdd.sum())
    assert(disabledTracker.execute(rdd.sum()) == rdd.sum())
  }

  test("tracker should throw exception if parallelism is not sufficient") {
    val nWorkers = numParallelism * 3
    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
    val tracker = new SparkParallelismTracker(sc, 1000, nWorkers)
    intercept[IllegalStateException] {
      tracker.execute {
        rdd.map { i =>
          // Test interruption
          Thread.sleep(Long.MaxValue)
          i
        }.sum()
      }
    }
  }
}