[jvm-packages] consider spark.task.cpus when controlling parallelism (#3530)
* add back train method but mark as deprecated
* fix scalastyle error
* consider spark.task.cpus when controlling parallelism
* fix bug
* fix conf setup
* calculate requestedCores within ParallelismController
* enforce spark.task.cpus = 1
* unify unit test case framework
* enable spark ui
This commit is contained in:
parent 860263f814
commit 6cf97b4eae
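The heart of the change: a Spark task can occupy more than one core, so waiting for nWorkers alive cores is no longer enough once spark.task.cpus > 1. A minimal sketch of the arithmetic the diff introduces (names follow the diff; the wrapper object and the public sc.getConf accessor are illustrative, the tracker itself lives in package org.apache.spark and can use the private sc.conf):

import org.apache.spark.SparkContext

object ParallelismMath {
  // Each Spark task claims spark.task.cpus cores, so numWorkers XGBoost
  // tasks can all be scheduled only once numWorkers * spark.task.cpus
  // cores are alive in the cluster.
  def requestedCores(sc: SparkContext, numWorkers: Int): Int =
    numWorkers * sc.getConf.getInt("spark.task.cpus", 1)
}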
@@ -222,8 +222,8 @@ object XGBoost extends Serializable {
     checkpointRound: Int =>
       val tracker = startTracker(nWorkers, trackerConf)
       try {
-        val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
         val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
+        val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
         val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
           tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
           prevBooster)
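Moving the tracker's construction after overrideParamsAccordingToTaskCPUs keeps the parameter override and the parallelism check adjacent; the helper's body is not part of this diff. A hypothetical sketch of what such a helper could do, pinning XGBoost's nthread to the per-task core budget (the method body here is an assumption, only the name and call shape come from the call site):

import org.apache.spark.SparkContext

object TaskCpusOverride {
  // Hypothetical stand-in: "nthread" must not exceed spark.task.cpus,
  // and defaults to it when unset.
  def overrideParamsAccordingToTaskCPUs(
      params: Map[String, Any], sc: SparkContext): Map[String, Any] = {
    val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
    params.get("nthread") match {
      case Some(n) =>
        require(n.toString.toInt <= coresPerTask,
          s"nthread ($n) must not exceed spark.task.cpus ($coresPerTask)")
        params
      case None =>
        params + ("nthread" -> coresPerTask) // default to the task's core share
    }
  }
}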
@@ -33,13 +33,14 @@ import scala.concurrent.{Await, Future, TimeoutException}
  *
  * @param sc The SparkContext object
  * @param timeout The maximum time to wait for enough number of workers.
- * @param nWorkers nWorkers used in an XGBoost Job
+ * @param numWorkers nWorkers used in an XGBoost Job
  */
 class SparkParallelismTracker(
     val sc: SparkContext,
     timeout: Long,
-    nWorkers: Int) {
+    numWorkers: Int) {
 
+  private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
   private[this] val mapper = new ObjectMapper()
   private[this] val logger = LogFactory.getLog("XGBoostSpark")
   private[this] val url = sc.uiWebUrl match {
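requestedCores is the number the liveness check below waits for. How the tracker counts alive cores is outside this hunk; given the mapper and sc.uiWebUrl fields, a plausible sketch polls Spark's REST API and sums executor cores (the endpoint path and the totalCores field name are assumptions about this implementation, though both exist in Spark's monitoring API):

import com.fasterxml.jackson.databind.ObjectMapper
import scala.io.Source

object AliveCores {
  // Sum totalCores across the executor summaries served by the Spark UI.
  def numAliveCores(uiUrl: String, appId: String): Int = {
    val mapper = new ObjectMapper()
    val json = Source.fromURL(s"$uiUrl/api/v1/applications/$appId/executors").mkString
    val executors = mapper.readTree(json) // JSON array of executor summaries
    var cores = 0
    val it = executors.elements()
    while (it.hasNext) cores += it.next().get("totalCores").asInt()
    cores
  }
}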
@@ -76,7 +77,7 @@ class SparkParallelismTracker(
   }
 
   private[this] def safeExecute[T](body: => T): T = {
-    val listener = new TaskFailedListener;
+    val listener = new TaskFailedListener
     sc.addSparkListener(listener)
     try {
       body
@@ -99,10 +100,11 @@ class SparkParallelismTracker(
       body
     } else {
       try {
-        waitForCondition(numAliveCores >= nWorkers, timeout)
+        waitForCondition(numAliveCores >= requestedCores, timeout)
       } catch {
         case _: TimeoutException =>
-          throw new IllegalStateException(s"Unable to get $nWorkers workers for XGBoost training")
+          throw new IllegalStateException(s"Unable to get $requestedCores workers for" +
+            s" XGBoost training")
       }
       safeExecute(body)
     }
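This hunk is where the fix takes effect: training now blocks until numWorkers * spark.task.cpus cores are alive, not just numWorkers. The file imports Await, Future and TimeoutException, which suggests a Future-based wait; a plain polling equivalent of waitForCondition, for illustration only:

import scala.concurrent.TimeoutException

object Wait {
  // Re-evaluate the by-name condition until it holds or the deadline passes.
  def waitForCondition(condition: => Boolean, timeoutMillis: Long): Unit = {
    val deadline = System.currentTimeMillis() + timeoutMillis
    while (!condition) {
      if (System.currentTimeMillis() > deadline) {
        throw new TimeoutException(s"condition not met in $timeoutMillis ms")
      }
      Thread.sleep(100) // poll interval; the real value is an assumption
    }
  }
}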
@@ -19,7 +19,8 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
 
@@ -37,6 +38,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       .appName("XGBoostSuite")
       .config("spark.ui.enabled", false)
       .config("spark.driver.memory", "512m")
+      .config("spark.task.cpus", 1)
 
   override def beforeEach(): Unit = getOrCreateSession
 
@@ -16,22 +16,24 @@
 
 package org.apache.spark
 
+import org.scalatest.FunSuite
+import _root_.ml.dmlc.xgboost4j.scala.spark.PerTest
+
 import org.apache.spark.rdd.RDD
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.apache.spark.sql.SparkSession
 
-class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
-  var sc: SparkContext = _
-  var numParallelism: Int = _
+class SparkParallelismTrackerSuite extends FunSuite with PerTest {
 
-  override def beforeAll(): Unit = {
-    val conf: SparkConf = new SparkConf()
-      .setMaster("local[*]")
-      .setAppName("XGBoostSuite")
-    sc = new SparkContext(conf)
-    numParallelism = sc.defaultParallelism
-  }
+  val numParallelism: Int = Runtime.getRuntime.availableProcessors()
 
-  test("tracker should not affect execution result") {
+  override protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
+    .master("local[*]")
+    .appName("XGBoostSuite")
+    .config("spark.ui.enabled", true)
+    .config("spark.driver.memory", "512m")
+    .config("spark.task.cpus", 1)
+
+  test("tracker should not affect execution result when timeout is not larger than 0") {
     val nWorkers = numParallelism
     val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
     val tracker = new SparkParallelismTracker(sc, 10000, nWorkers)
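Two details of the rewritten suite are worth noting. PerTest pins spark.task.cpus to 1, so requestedCores equals nWorkers for every existing test, and this suite overrides the builder to re-enable the Spark UI, since the tracker resolves its endpoint from sc.uiWebUrl, which is None when the UI is off. A quick way to observe that dependency, assuming a local Spark setup:

import org.apache.spark.{SparkConf, SparkContext}

object UiCheck extends App {
  val sc = new SparkContext(new SparkConf()
    .setMaster("local[2]")
    .setAppName("ui-check")
    .set("spark.ui.enabled", "true"))
  println(sc.uiWebUrl) // Some(http://host:4040); None if spark.ui.enabled=false
  sc.stop()
}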
@@ -54,4 +56,21 @@ class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
       }
     }
   }
+
+  test("tracker should throw exception if parallelism is not sufficient with" +
+    " spark.task.cpus larger than 1") {
+    sc.conf.set("spark.task.cpus", "2")
+    val nWorkers = numParallelism
+    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
+    val tracker = new SparkParallelismTracker(sc, 1000, nWorkers)
+    intercept[IllegalStateException] {
+      tracker.execute {
+        rdd.map { i =>
+          // Test interruption
+          Thread.sleep(Long.MaxValue)
+          i
+        }.sum()
+      }
+    }
+  }
 }
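Why the new test must throw: under local[*] the number of alive cores can never exceed the machine's processor count, while spark.task.cpus = 2 doubles the total being requested. Illustrative arithmetic only, the core count is an example:

object WhyItTimesOut extends App {
  val available = Runtime.getRuntime.availableProcessors() // e.g. 8 under local[*]
  val nWorkers  = available
  val taskCpus  = 2
  val requested = nWorkers * taskCpus                      // e.g. 16
  // numAliveCores can never reach `requested`, so waitForCondition times out
  // and the tracker surfaces it as an IllegalStateException.
  assert(requested > available)
}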