[jvm-packages] consider spark.task.cpus when controlling parallelism (#3530)
* add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * consider spark.task.cpus when controlling parallelism * fix bug * fix conf setup * calculate requestedCores within ParallelismController * enforce spark.task.cpus = 1 * unify unit test case framework * enable spark ui
This commit is contained in:
@@ -222,8 +222,8 @@ object XGBoost extends Serializable {
|
||||
checkpointRound: Int =>
|
||||
val tracker = startTracker(nWorkers, trackerConf)
|
||||
try {
|
||||
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
|
||||
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
|
||||
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
|
||||
val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
|
||||
tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
|
||||
prevBooster)
|
||||
|
||||
@@ -33,13 +33,14 @@ import scala.concurrent.{Await, Future, TimeoutException}
|
||||
*
|
||||
* @param sc The SparkContext object
|
||||
* @param timeout The maximum time to wait for enough number of workers.
|
||||
* @param nWorkers nWorkers used in an XGBoost Job
|
||||
* @param numWorkers nWorkers used in an XGBoost Job
|
||||
*/
|
||||
class SparkParallelismTracker(
|
||||
val sc: SparkContext,
|
||||
timeout: Long,
|
||||
nWorkers: Int) {
|
||||
numWorkers: Int) {
|
||||
|
||||
private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
|
||||
private[this] val mapper = new ObjectMapper()
|
||||
private[this] val logger = LogFactory.getLog("XGBoostSpark")
|
||||
private[this] val url = sc.uiWebUrl match {
|
||||
@@ -76,7 +77,7 @@ class SparkParallelismTracker(
|
||||
}
|
||||
|
||||
private[this] def safeExecute[T](body: => T): T = {
|
||||
val listener = new TaskFailedListener;
|
||||
val listener = new TaskFailedListener
|
||||
sc.addSparkListener(listener)
|
||||
try {
|
||||
body
|
||||
@@ -99,10 +100,11 @@ class SparkParallelismTracker(
|
||||
body
|
||||
} else {
|
||||
try {
|
||||
waitForCondition(numAliveCores >= nWorkers, timeout)
|
||||
waitForCondition(numAliveCores >= requestedCores, timeout)
|
||||
} catch {
|
||||
case _: TimeoutException =>
|
||||
throw new IllegalStateException(s"Unable to get $nWorkers workers for XGBoost training")
|
||||
throw new IllegalStateException(s"Unable to get $requestedCores workers for" +
|
||||
s" XGBoost training")
|
||||
}
|
||||
safeExecute(body)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user