[jvm-packages] consider spark.task.cpus when controlling parallelism (#3530)
* add back train method but mark as deprecated
* add back train method but mark as deprecated
* fix scalastyle error
* fix scalastyle error
* consider spark.task.cpus when controlling parallelism
* fix bug
* fix conf setup
* calculate requestedCores within ParallelismController
* enforce spark.task.cpus = 1
* unify unit test case framework
* enable spark ui
This commit is contained in:
parent
860263f814
commit
6cf97b4eae
@ -222,8 +222,8 @@ object XGBoost extends Serializable {
|
||||
checkpointRound: Int =>
|
||||
val tracker = startTracker(nWorkers, trackerConf)
|
||||
try {
|
||||
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
|
||||
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
|
||||
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
|
||||
val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
|
||||
tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
|
||||
prevBooster)
|
||||
|
||||
@ -33,13 +33,14 @@ import scala.concurrent.{Await, Future, TimeoutException}
|
||||
*
|
||||
* @param sc The SparkContext object
|
||||
* @param timeout The maximum time to wait for enough number of workers.
|
||||
* @param nWorkers nWorkers used in an XGBoost Job
|
||||
* @param numWorkers nWorkers used in an XGBoost Job
|
||||
*/
|
||||
class SparkParallelismTracker(
|
||||
val sc: SparkContext,
|
||||
timeout: Long,
|
||||
nWorkers: Int) {
|
||||
numWorkers: Int) {
|
||||
|
||||
private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
|
||||
private[this] val mapper = new ObjectMapper()
|
||||
private[this] val logger = LogFactory.getLog("XGBoostSpark")
|
||||
private[this] val url = sc.uiWebUrl match {
|
||||
@ -76,7 +77,7 @@ class SparkParallelismTracker(
|
||||
}
|
||||
|
||||
private[this] def safeExecute[T](body: => T): T = {
|
||||
val listener = new TaskFailedListener;
|
||||
val listener = new TaskFailedListener
|
||||
sc.addSparkListener(listener)
|
||||
try {
|
||||
body
|
||||
@ -99,10 +100,11 @@ class SparkParallelismTracker(
|
||||
body
|
||||
} else {
|
||||
try {
|
||||
waitForCondition(numAliveCores >= nWorkers, timeout)
|
||||
waitForCondition(numAliveCores >= requestedCores, timeout)
|
||||
} catch {
|
||||
case _: TimeoutException =>
|
||||
throw new IllegalStateException(s"Unable to get $nWorkers workers for XGBoost training")
|
||||
throw new IllegalStateException(s"Unable to get $requestedCores workers for" +
|
||||
s" XGBoost training")
|
||||
}
|
||||
safeExecute(body)
|
||||
}
|
||||
|
||||
@ -19,7 +19,8 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import java.io.File
|
||||
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import org.apache.spark.SparkContext
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.{BeforeAndAfterEach, FunSuite}
|
||||
|
||||
@ -37,6 +38,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
||||
.appName("XGBoostSuite")
|
||||
.config("spark.ui.enabled", false)
|
||||
.config("spark.driver.memory", "512m")
|
||||
.config("spark.task.cpus", 1)
|
||||
|
||||
override def beforeEach(): Unit = getOrCreateSession
|
||||
|
||||
|
||||
@ -16,22 +16,24 @@
|
||||
|
||||
package org.apache.spark
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
import _root_.ml.dmlc.xgboost4j.scala.spark.PerTest
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
|
||||
class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
|
||||
var sc: SparkContext = _
|
||||
var numParallelism: Int = _
|
||||
class SparkParallelismTrackerSuite extends FunSuite with PerTest {
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
val conf: SparkConf = new SparkConf()
|
||||
.setMaster("local[*]")
|
||||
.setAppName("XGBoostSuite")
|
||||
sc = new SparkContext(conf)
|
||||
numParallelism = sc.defaultParallelism
|
||||
}
|
||||
val numParallelism: Int = Runtime.getRuntime.availableProcessors()
|
||||
|
||||
test("tracker should not affect execution result") {
|
||||
override protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
|
||||
.master("local[*]")
|
||||
.appName("XGBoostSuite")
|
||||
.config("spark.ui.enabled", true)
|
||||
.config("spark.driver.memory", "512m")
|
||||
.config("spark.task.cpus", 1)
|
||||
|
||||
test("tracker should not affect execution result when timeout is not larger than 0") {
|
||||
val nWorkers = numParallelism
|
||||
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
|
||||
val tracker = new SparkParallelismTracker(sc, 10000, nWorkers)
|
||||
@ -54,4 +56,21 @@ class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("tracker should throw exception if parallelism is not sufficient with" +
|
||||
" spark.task.cpus larger than 1") {
|
||||
sc.conf.set("spark.task.cpus", "2")
|
||||
val nWorkers = numParallelism
|
||||
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
|
||||
val tracker = new SparkParallelismTracker(sc, 1000, nWorkers)
|
||||
intercept[IllegalStateException] {
|
||||
tracker.execute {
|
||||
rdd.map { i =>
|
||||
// Test interruption
|
||||
Thread.sleep(Long.MaxValue)
|
||||
i
|
||||
}.sum()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user