From b678e1711d9c3432d21409c4249646a1691707a9 Mon Sep 17 00:00:00 2001
From: Yun Ni
Date: Mon, 16 Oct 2017 20:18:47 -0700
Subject: [PATCH] [jvm-packages] Add SparkParallelismTracker to prevent job
 from hanging (#2697)

* Add SparkParallelismTracker to prevent job from hanging

* Code review comments

* Code Review Comments

* Fix unit tests

* Changes and unit test to catch the corner case.

* Update documentation

* Small improvements

* cancelAllJobs is problematic with scalatest. Remove it

* Code Review Comments

* Check number of executor cores beforehand, and throw exception if any core is lost.

* Address CR Comments

* Add missing class

* Fix flaky unit test

* Address CR comments

* Remove redundant param for TaskFailedListener
---
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala       |  16 ++-
 .../scala/spark/params/GeneralParams.scala         |   9 +-
 .../spark/SparkParallelismTracker.scala            | 116 ++++++++++++++++++
 .../scala/spark/XGBoostGeneralSuite.scala          |  13 ++
 .../spark/SparkParallelismTrackerSuite.scala       |  57 +++++++++
 5 files changed, 206 insertions(+), 5 deletions(-)
 create mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
 create mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index f3ab0cd08..d5ed85d0d 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -23,13 +23,12 @@ import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-
 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.{FSDataInputStream, Path}
-import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
-import org.apache.spark.{SparkContext, TaskContext}
+import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
+import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
 
 object TrackerConf {
   def apply(): TrackerConf = TrackerConf(0L, "python")
@@ -315,8 +314,17 @@ object XGBoost extends Serializable {
       case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
         "instance of TrackerConf.")
     }
+    val timeoutRequestWorkers: Long = params.get("timeout_request_workers") match {
+      case None => 0L
+      case Some(interval: Long) => interval
+      case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
+        " an instance of Long.")
+    }
+
     val tracker = startTracker(nWorkers, trackerConf)
     try {
+      val sc = trainingData.sparkContext
+      val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
       val overriddenParams = overrideParamsAccordingToTaskCPUs(params, trainingData.sparkContext)
       val boosters = buildDistributedBoosters(trainingData, overriddenParams,
         tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
@@ -329,7 +337,7 @@
       sparkJobThread.setUncaughtExceptionHandler(tracker)
       sparkJobThread.start()
       val isClsTask = isClassificationTask(params)
-      val trackerReturnVal = tracker.waitFor(0L)
+      val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
       logger.info(s"Rabit returns with exit code $trackerReturnVal")
       val model = postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams,
         sparkJobThread, isClsTask)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
index 676c4eb47..96dada6cb 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
@@ -70,6 +70,13 @@ trait GeneralParams extends Params {
    */
   val missing = new FloatParam(this, "missing", "the value treated as missing")
 
+  /**
+   * the maximum time to wait until total numCores is no smaller than nWorkers. default: 30 minutes
+   */
+  val timeoutRequestWorkers = new LongParam(this, "timeout_request_workers", "the maximum time to" +
+    " request new workers if numCores is insufficient. The timeout will be disabled if this" +
+    " value is set to a value smaller than or equal to 0.")
+
   /**
    * Rabit tracker configurations. The parameter must be provided as an instance of the
    * TrackerConf class, which has the following definition:
@@ -105,6 +112,6 @@ trait GeneralParams extends Params {
   setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
     useExternalMemory -> false, silent -> 0,
     customObj -> null, customEval -> null, missing -> Float.NaN,
-    trackerConf -> TrackerConf(), seed -> 0
+    trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L
   )
 }
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
new file mode 100644
index 000000000..6172a2588
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -0,0 +1,116 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.net.URL
+
+import org.apache.commons.logging.LogFactory
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.codehaus.jackson.map.ObjectMapper
+
+import scala.collection.JavaConverters._
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Future, TimeoutException}
+
+/**
+ * A tracker that ensures that enough executor cores are alive.
+ * Throws an exception when the number of alive cores is less than nWorkers.
+ *
+ * @param sc The SparkContext object
+ * @param timeout The maximum time to wait for enough workers.
+ * @param nWorkers The number of workers used in an XGBoost job
+ */
+class SparkParallelismTracker(
+    val sc: SparkContext,
+    timeout: Long,
+    nWorkers: Int) {
+
+  private[this] val mapper = new ObjectMapper()
+  private[this] val logger = LogFactory.getLog("XGBoostSpark")
+  private[this] val url = sc.uiWebUrl match {
+    case Some(baseUrl) => new URL(s"$baseUrl/api/v1/applications/${sc.applicationId}/executors")
+    case _ => null
+  }
+
+  private[this] def numAliveCores: Int = {
+    try {
+      mapper.readTree(url).findValues("totalCores").asScala.map(_.asInt).sum
+    } catch {
+      case ex: Throwable =>
+        logger.warn("Unable to read the total number of alive cores from the REST API. " +
+          "The health check will be ignored.")
+        ex.printStackTrace()
+        Int.MaxValue
+    }
+  }
+
+  private[this] def waitForCondition(
+      condition: => Boolean,
+      timeout: Long,
+      checkInterval: Long = 100L) = {
+    val monitor = Future {
+      while (!condition) {
+        Thread.sleep(checkInterval)
+      }
+    }
+    Await.ready(monitor, timeout.millis)
+  }
+
+  private[this] def safeExecute[T](body: => T): T = {
+    sc.listenerBus.listeners.add(0, new TaskFailedListener)
+    try {
+      body
+    } finally {
+      sc.listenerBus.listeners.remove(0)
+    }
+  }
+
+  /**
+   * Execute a blocking function call with two checks that enough workers are available:
+   *  - Before the function starts, wait until there are enough executor cores.
+   *  - During execution, throw an exception if any executor is lost.
+   *
+   * @param body A blocking function call
+   * @tparam T Return type
+   * @return The return value of body
+   */
+  def execute[T](body: => T): T = {
+    if (timeout <= 0) {
+      body
+    } else {
+      try {
+        waitForCondition(numAliveCores >= nWorkers, timeout)
+      } catch {
+        case _: TimeoutException =>
+          throw new IllegalStateException(s"Unable to get $nWorkers workers for XGBoost training")
+      }
+      safeExecute(body)
+    }
+  }
+}
+
+private[spark] class TaskFailedListener extends SparkListener {
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    taskEnd.reason match {
+      case reason: TaskFailedReason =>
+        throw new InterruptedException("Executor lost during XGBoost training: " +
+          s"${reason.toErrorString}")
+      case _ =>
+    }
+  }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 0b96a5f2c..dc2ef9672 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -376,4 +376,17 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     val predResult1: Array[Array[Float]] = predRDD.collect()
     assert(testRDD.count() === predResult1.length)
   }
+
+  test("training with spark parallelism checks disabled") {
+    import DataUtils._
+    val eval = new EvalError()
+    val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
+    val testSetDMatrix = new DMatrix(Classification.test.iterator)
+    val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic", "timeout_request_workers" -> 0L).toMap
+    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
+      nWorkers = numWorkers)
+    assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
+      testSetDMatrix) < 0.1)
+  }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
new file mode 100644
index 000000000..e6659c937
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
@@ -0,0 +1,57 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.rdd.RDD
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
+  var sc: SparkContext = _
+  var numParallelism: Int = _
+
+  override def beforeAll(): Unit = {
+    val conf: SparkConf = new SparkConf()
+      .setMaster("local[*]")
+      .setAppName("XGBoostSuite")
+    sc = new SparkContext(conf)
+    numParallelism = sc.defaultParallelism
+  }
+
+  test("tracker should not affect execution result") {
+    val nWorkers = numParallelism
+    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
+    val tracker = new SparkParallelismTracker(sc, 10000, nWorkers)
+    val disabledTracker = new SparkParallelismTracker(sc, 0, nWorkers)
+    assert(tracker.execute(rdd.sum()) == rdd.sum())
+    assert(disabledTracker.execute(rdd.sum()) == rdd.sum())
+  }
+
+  test("tracker should throw exception if parallelism is not sufficient") {
+    val nWorkers = numParallelism * 3
+    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
+    val tracker = new SparkParallelismTracker(sc, 1000, nWorkers)
+    intercept[IllegalStateException] {
+      tracker.execute {
+        rdd.map { i =>
+          // Test interruption
+          Thread.sleep(Long.MaxValue)
+          i
+        }.sum()
+      }
+    }
+  }
+}
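
For callers, opting into the new guard is just a matter of setting "timeout_request_workers"
in the parameter map, as the test above does to disable it. Below is a minimal usage sketch,
not taken from the patch itself: trainingRDD and numWorkers are placeholder names for an
existing RDD of MLLabeledPoint and the desired worker count, and the five-minute timeout and
booster parameters are arbitrary example values.

    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    // Wait up to five minutes for enough executor cores before training starts.
    // If the cluster cannot provide them in time, training fails fast with an
    // IllegalStateException instead of hanging; a value <= 0 disables the check.
    val paramMap = Map(
      "eta" -> "1",
      "max_depth" -> "6",
      "objective" -> "binary:logistic",
      "timeout_request_workers" -> 5 * 60 * 1000L)

    val model = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
      nWorkers = numWorkers)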