[jvm-packages] consider spark.task.cpus when controlling parallelism (#3530)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* consider spark.task.cpus when controlling parallelism

* fix bug

* fix conf setup

* calculate requestedCores within ParallelismController

* enforce spark.task.cpus = 1

* unify unit test case framework

* enable spark ui
Nan Zhu, 2018-07-31 06:19:45 -07:00, committed by GitHub
parent 860263f814
commit 6cf97b4eae
4 changed files with 42 additions and 19 deletions
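The substance of the change, before the per-file diffs: SparkParallelismTracker previously waited for nWorkers alive cores before starting training, which under-counts whenever spark.task.cpus > 1, because each XGBoost worker runs as a single Spark task and every task claims spark.task.cpus cores. The tracker now waits for numWorkers * spark.task.cpus cores. A minimal sketch of that calculation follows; the helper name is hypothetical, and it uses the public sc.getConf, whereas the tracker itself sits in package org.apache.spark and can read the private[spark] sc.conf directly.

import org.apache.spark.SparkContext

// Hypothetical helper for illustration; the commit computes this value
// inside SparkParallelismTracker (see the second file below).
object RequestedCores {
  // One XGBoost worker == one Spark task, and each task claims
  // spark.task.cpus cores, so training needs
  // numWorkers * spark.task.cpus alive cores before it can start.
  def apply(sc: SparkContext, numWorkers: Int): Int =
    numWorkers * sc.getConf.getInt("spark.task.cpus", 1)
}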

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

@@ -222,8 +222,8 @@ object XGBoost extends Serializable {
       checkpointRound: Int =>
         val tracker = startTracker(nWorkers, trackerConf)
         try {
-          val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
           val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
+          val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
           val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
             tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
             prevBooster)
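The hunk above moves the tracker construction to after the params are overridden, so the task-CPU handling runs first. The body of overrideParamsAccordingToTaskCPUs is not part of this diff; the following is a hedged sketch of what a method with this name and purpose plausibly does (aligning XGBoost's nthread parameter with spark.task.cpus) — treat every detail as an assumption, not the committed code.

import org.apache.spark.SparkContext

object ParamOverrideSketch {
  // Sketch only: keep each worker's thread count consistent with the
  // number of cores its Spark task actually owns.
  def overrideParamsAccordingToTaskCPUs(
      params: Map[String, Any],
      sc: SparkContext): Map[String, Any] = {
    val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
    params.get("nthread") match {
      // An explicit nthread must not exceed the cores a task may use.
      case Some(n) =>
        require(n.toString.toInt <= coresPerTask,
          s"nthread must be no larger than spark.task.cpus ($coresPerTask)")
        params
      // Otherwise give each worker exactly the cores its task owns.
      case None => params + ("nthread" -> coresPerTask)
    }
  }
}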

jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala

@@ -33,13 +33,14 @@ import scala.concurrent.{Await, Future, TimeoutException}
  *
  * @param sc The SparkContext object
  * @param timeout The maximum time to wait for enough number of workers.
- * @param nWorkers nWorkers used in an XGBoost Job
+ * @param numWorkers nWorkers used in an XGBoost Job
  */
 class SparkParallelismTracker(
     val sc: SparkContext,
     timeout: Long,
-    nWorkers: Int) {
+    numWorkers: Int) {
 
+  private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
   private[this] val mapper = new ObjectMapper()
   private[this] val logger = LogFactory.getLog("XGBoostSpark")
   private[this] val url = sc.uiWebUrl match {
@@ -76,7 +77,7 @@ class SparkParallelismTracker(
   }
 
   private[this] def safeExecute[T](body: => T): T = {
-    val listener = new TaskFailedListener;
+    val listener = new TaskFailedListener
     sc.addSparkListener(listener)
     try {
       body
@@ -99,10 +100,11 @@ class SparkParallelismTracker(
       body
     } else {
       try {
-        waitForCondition(numAliveCores >= nWorkers, timeout)
+        waitForCondition(numAliveCores >= requestedCores, timeout)
       } catch {
         case _: TimeoutException =>
-          throw new IllegalStateException(s"Unable to get $nWorkers workers for XGBoost training")
+          throw new IllegalStateException(s"Unable to get $requestedCores workers for" +
+            s" XGBoost training")
       }
       safeExecute(body)
     }
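The hunks above lean on two members whose bodies are not shown in this diff: numAliveCores (the fields visible earlier suggest it is read from the Spark UI REST endpoint via sc.uiWebUrl and the Jackson ObjectMapper) and waitForCondition (the file imports Await, Future, and TimeoutException). A hedged, self-contained sketch of the polling pattern; the class and parameter names are invented for illustration, and aliveCores is injected here rather than fetched over HTTP.

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent.{Await, Future, TimeoutException}

// Illustrative stand-in for the tracker's wait logic, under the
// assumptions stated above.
class CoreWaiter(aliveCores: () => Int, requestedCores: Int, timeout: Long) {

  // Re-check the by-name condition periodically until it holds,
  // failing once the timeout budget expires.
  private def waitForCondition(condition: => Boolean, timeout: Long): Unit = {
    val waiter = Future { while (!condition) Thread.sleep(100) }
    Await.result(waiter, timeout.millis) // throws TimeoutException on expiry
  }

  def ensureParallelism(): Unit =
    try {
      waitForCondition(aliveCores() >= requestedCores, timeout)
    } catch {
      case _: TimeoutException =>
        throw new IllegalStateException(
          s"Unable to get $requestedCores cores for XGBoost training")
    }
}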

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala

@@ -19,7 +19,8 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
@@ -37,6 +38,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       .appName("XGBoostSuite")
       .config("spark.ui.enabled", false)
       .config("spark.driver.memory", "512m")
+      .config("spark.task.cpus", 1)
 
   override def beforeEach(): Unit = getOrCreateSession
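PerTest is the shared per-suite fixture; only its session builder and beforeEach hook appear in this diff. Pinning spark.task.cpus to 1 here keeps the one-core-per-task assumption for ordinary suites, while SparkParallelismTrackerSuite (next file) overrides the builder. A hedged sketch of how such a trait typically fits together — the builder configs and getOrCreateSession are implied by the hunks, everything else (the local[4] master, the afterEach cleanup) is an assumption.

import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterEach, FunSuite}

trait PerTestSketch extends BeforeAndAfterEach { self: FunSuite =>

  @transient private var currentSession: SparkSession = _

  // Suites can override this builder; SparkParallelismTrackerSuite does so
  // to enable the Spark UI that the tracker polls.
  protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
    .master("local[4]")
    .appName("XGBoostSuite")
    .config("spark.ui.enabled", false)
    .config("spark.driver.memory", "512m")
    .config("spark.task.cpus", 1)

  protected def getOrCreateSession: SparkSession = {
    if (currentSession == null) currentSession = sparkSessionBuilder.getOrCreate()
    currentSession
  }

  override def beforeEach(): Unit = getOrCreateSession

  // Stop the session after each test so per-suite configs don't leak.
  override def afterEach(): Unit =
    if (currentSession != null) {
      currentSession.stop()
      currentSession = null
    }
}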

jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala

@@ -16,22 +16,24 @@
 package org.apache.spark
 
+import org.scalatest.FunSuite
+
+import _root_.ml.dmlc.xgboost4j.scala.spark.PerTest
 import org.apache.spark.rdd.RDD
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.apache.spark.sql.SparkSession
 
-class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
-  var sc: SparkContext = _
-  var numParallelism: Int = _
+class SparkParallelismTrackerSuite extends FunSuite with PerTest {
 
-  override def beforeAll(): Unit = {
-    val conf: SparkConf = new SparkConf()
-      .setMaster("local[*]")
-      .setAppName("XGBoostSuite")
-    sc = new SparkContext(conf)
-    numParallelism = sc.defaultParallelism
-  }
+  val numParallelism: Int = Runtime.getRuntime.availableProcessors()
 
-  test("tracker should not affect execution result") {
+  override protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
+    .master("local[*]")
+    .appName("XGBoostSuite")
+    .config("spark.ui.enabled", true)
+    .config("spark.driver.memory", "512m")
+    .config("spark.task.cpus", 1)
+
+  test("tracker should not affect execution result when timeout is not larger than 0") {
     val nWorkers = numParallelism
     val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
     val tracker = new SparkParallelismTracker(sc, 10000, nWorkers)
@@ -54,4 +56,21 @@ class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
       }
     }
   }
+
+  test("tracker should throw exception if parallelism is not sufficient with" +
+    " spark.task.cpus larger than 1") {
+    sc.conf.set("spark.task.cpus", "2")
+    val nWorkers = numParallelism
+    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
+    val tracker = new SparkParallelismTracker(sc, 1000, nWorkers)
+    intercept[IllegalStateException] {
+      tracker.execute {
+        rdd.map { i =>
+          // Test interruption
+          Thread.sleep(Long.MaxValue)
+          i
+        }.sum()
+      }
+    }
+  }
 }
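Design note on the new test: with spark.task.cpus set to 2, requestedCores becomes nWorkers * 2 = 2 * Runtime.getRuntime.availableProcessors(), which a local[*] master can never supply, so waitForCondition exhausts its 1000 ms budget and execute throws IllegalStateException before any task is scheduled. The Thread.sleep(Long.MaxValue) body would block forever if it were ever executed, so the test passes only when the tracker fails fast.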