[jvm-packages] consider spark.task.cpus when controlling parallelism (#3530)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* consider spark.task.cpus when controlling parallelism

* fix bug

* fix conf setup

* calculate requestedCores within ParallelismController

* enforce spark.task.cpus = 1

* unify unit test case framework

* enable spark ui
This commit is contained in:
Nan Zhu
2018-07-31 06:19:45 -07:00
committed by GitHub
parent 860263f814
commit 6cf97b4eae
4 changed files with 42 additions and 19 deletions

View File

@@ -33,13 +33,14 @@ import scala.concurrent.{Await, Future, TimeoutException}
*
* @param sc The SparkContext object
* @param timeout The maximum time to wait for enough number of workers.
* @param nWorkers nWorkers used in an XGBoost Job
* @param numWorkers The number of workers used in an XGBoost job
*/
class SparkParallelismTracker(
val sc: SparkContext,
timeout: Long,
nWorkers: Int) {
numWorkers: Int) {
private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
private[this] val mapper = new ObjectMapper()
private[this] val logger = LogFactory.getLog("XGBoostSpark")
private[this] val url = sc.uiWebUrl match {
@@ -76,7 +77,7 @@ class SparkParallelismTracker(
}
private[this] def safeExecute[T](body: => T): T = {
val listener = new TaskFailedListener;
val listener = new TaskFailedListener
sc.addSparkListener(listener)
try {
body
@@ -99,10 +100,11 @@ class SparkParallelismTracker(
body
} else {
try {
waitForCondition(numAliveCores >= nWorkers, timeout)
waitForCondition(numAliveCores >= requestedCores, timeout)
} catch {
case _: TimeoutException =>
throw new IllegalStateException(s"Unable to get $nWorkers workers for XGBoost training")
throw new IllegalStateException(s"Unable to get $requestedCores workers for" +
s" XGBoost training")
}
safeExecute(body)
}