From 6cf97b4eae9160109f81730949ee5887404ab7ef Mon Sep 17 00:00:00 2001
From: Nan Zhu
Date: Tue, 31 Jul 2018 06:19:45 -0700
Subject: [PATCH] [jvm-packages] consider spark.task.cpus when controlling
 parallelism (#3530)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* consider spark.task.cpus when controlling parallelism

* fix bug

* fix conf setup

* calculate requestedCores within ParallelismController

* enforce spark.task.cpus = 1

* unify unit test case framework

* enable spark ui
---
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala  |  2 +-
 .../spark/SparkParallelismTracker.scala       | 12 +++---
 .../dmlc/xgboost4j/scala/spark/PerTest.scala  |  4 +-
 .../spark/SparkParallelismTrackerSuite.scala  | 43 +++++++++++++------
 4 files changed, 42 insertions(+), 19 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index a419b12c5..ccc37ebea 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -222,8 +222,8 @@ object XGBoost extends Serializable {
     checkpointRound: Int =>
       val tracker = startTracker(nWorkers, trackerConf)
       try {
-        val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
         val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
+        val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
         val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
           tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
           prevBooster)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
index 75c5fb484..0f430b219 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -33,13 +33,14 @@ import scala.concurrent.{Await, Future, TimeoutException}
  *
  * @param sc The SparkContext object
  * @param timeout The maximum time to wait for enough number of workers.
- * @param nWorkers nWorkers used in an XGBoost Job
+ * @param numWorkers The number of workers used in an XGBoost job
  */
 class SparkParallelismTracker(
     val sc: SparkContext,
     timeout: Long,
-    nWorkers: Int) {
+    numWorkers: Int) {
 
+  private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
   private[this] val mapper = new ObjectMapper()
   private[this] val logger = LogFactory.getLog("XGBoostSpark")
   private[this] val url = sc.uiWebUrl match {
@@ -76,7 +77,7 @@ class SparkParallelismTracker(
   }
 
   private[this] def safeExecute[T](body: => T): T = {
-    val listener = new TaskFailedListener;
+    val listener = new TaskFailedListener
     sc.addSparkListener(listener)
     try {
       body
@@ -99,10 +100,11 @@ class SparkParallelismTracker(
       body
     } else {
       try {
-        waitForCondition(numAliveCores >= nWorkers, timeout)
+        waitForCondition(numAliveCores >= requestedCores, timeout)
       } catch {
         case _: TimeoutException =>
-          throw new IllegalStateException(s"Unable to get $nWorkers workers for XGBoost training")
+          throw new IllegalStateException(s"Unable to get $requestedCores cores for" +
+            s" XGBoost training")
       }
       safeExecute(body)
     }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index 9e617aecf..7bba5f342 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -19,7 +19,8 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-import org.apache.spark.SparkContext
+
+import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
 
@@ -37,6 +38,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
     .appName("XGBoostSuite")
     .config("spark.ui.enabled", false)
     .config("spark.driver.memory", "512m")
+    .config("spark.task.cpus", 1)
 
   override def beforeEach(): Unit = getOrCreateSession
 
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
index e6659c937..ba3b15338 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
@@ -16,22 +16,24 @@
 
 package org.apache.spark
 
+import org.scalatest.FunSuite
+import _root_.ml.dmlc.xgboost4j.scala.spark.PerTest
+
 import org.apache.spark.rdd.RDD
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.apache.spark.sql.SparkSession
 
-class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
-  var sc: SparkContext = _
-  var numParallelism: Int = _
+class SparkParallelismTrackerSuite extends FunSuite with PerTest {
 
-  override def beforeAll(): Unit = {
-    val conf: SparkConf = new SparkConf()
-      .setMaster("local[*]")
-      .setAppName("XGBoostSuite")
-    sc = new SparkContext(conf)
-    numParallelism = sc.defaultParallelism
-  }
+  val numParallelism: Int = Runtime.getRuntime.availableProcessors()
 
-  test("tracker should not affect execution result") {
+  override protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
+    .master("local[*]")
+    .appName("XGBoostSuite")
+    .config("spark.ui.enabled", true)
+    .config("spark.driver.memory", "512m")
"512m") + .config("spark.task.cpus", 1) + + test("tracker should not affect execution result when timeout is not larger than 0") { val nWorkers = numParallelism val rdd: RDD[Int] = sc.parallelize(1 to nWorkers) val tracker = new SparkParallelismTracker(sc, 10000, nWorkers) @@ -54,4 +56,21 @@ class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll { } } } + + test("tracker should throw exception if parallelism is not sufficient with" + + " spark.task.cpus larger than 1") { + sc.conf.set("spark.task.cpus", "2") + val nWorkers = numParallelism + val rdd: RDD[Int] = sc.parallelize(1 to nWorkers) + val tracker = new SparkParallelismTracker(sc, 1000, nWorkers) + intercept[IllegalStateException] { + tracker.execute { + rdd.map { i => + // Test interruption + Thread.sleep(Long.MaxValue) + i + }.sum() + } + } + } }