[jvm-packages] refine numAliveCores method of SparkParallelismTracker (#4858)

* refine numAliveCores

* refine XGBoostToMLlibParams

* fix waitForCondition

* resolve conflicts

* Update SparkParallelismTracker.scala
This commit is contained in:
Xu Xiao 2019-09-20 06:18:29 +08:00 committed by Nan Zhu
parent 22209b7b95
commit 277e25797b
2 changed files with 20 additions and 46 deletions

View File

@ -251,17 +251,16 @@ private[spark] trait ParamMapFuncs extends Params {
" and grow_histmaker,prune or hist as the updater type") " and grow_histmaker,prune or hist as the updater type")
} }
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName) val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
params.find(_.name == name) match { params.find(_.name == name).foreach {
case None => case _: DoubleParam =>
case Some(_: DoubleParam) =>
set(name, paramValue.toString.toDouble) set(name, paramValue.toString.toDouble)
case Some(_: BooleanParam) => case _: BooleanParam =>
set(name, paramValue.toString.toBoolean) set(name, paramValue.toString.toBoolean)
case Some(_: IntParam) => case _: IntParam =>
set(name, paramValue.toString.toInt) set(name, paramValue.toString.toInt)
case Some(_: FloatParam) => case _: FloatParam =>
set(name, paramValue.toString.toFloat) set(name, paramValue.toString.toFloat)
case Some(_: Param[_]) => case _: Param[_] =>
set(name, paramValue) set(name, paramValue)
} }
} }

View File

@ -16,18 +16,8 @@
package org.apache.spark package org.apache.spark
import java.net.URL
import java.util.concurrent.atomic.AtomicBoolean
import org.apache.commons.logging.LogFactory import org.apache.commons.logging.LogFactory
import org.apache.spark.scheduler._ import org.apache.spark.scheduler._
import org.codehaus.jackson.map.ObjectMapper
import scala.collection.JavaConverters._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent.{Await, Future, TimeoutException}
import scala.util.control.ControlThrowable
/** /**
* A tracker that ensures enough number of executor cores are alive. * A tracker that ensures enough number of executor cores are alive.
@ -43,39 +33,28 @@ class SparkParallelismTracker(
numWorkers: Int) { numWorkers: Int) {
private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1) private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
private[this] val mapper = new ObjectMapper()
private[this] val logger = LogFactory.getLog("XGBoostSpark") private[this] val logger = LogFactory.getLog("XGBoostSpark")
private[this] val url = sc.uiWebUrl match {
case Some(baseUrl) => new URL(s"$baseUrl/api/v1/applications/${sc.applicationId}/executors")
case _ => null
}
private[this] def numAliveCores: Int = { private[this] def numAliveCores: Int = {
try { sc.statusStore.executorList(true).map(_.totalCores).sum
if (url != null) {
mapper.readTree(url).findValues("totalCores").asScala.map(_.asInt).sum
} else {
Int.MaxValue
}
} catch {
case ex: Throwable =>
logger.warn(s"Unable to read total number of alive cores from REST API." +
s"Health Check will be ignored.")
ex.printStackTrace()
Int.MaxValue
}
} }
private[this] def waitForCondition( private[this] def waitForCondition(
condition: => Boolean, condition: => Boolean,
timeout: Long, timeout: Long,
checkInterval: Long = 100L) = { checkInterval: Long = 100L) = {
val monitor = Future { val waitImpl = new ((Long, Boolean) => Boolean) {
while (!condition) { override def apply(waitedTime: Long, status: Boolean): Boolean = status match {
Thread.sleep(checkInterval) case s if s => true
case _ => waitedTime match {
case t if t < timeout =>
Thread.sleep(checkInterval)
apply(t + checkInterval, status = condition)
case _ => false
}
} }
} }
Await.ready(monitor, timeout.millis) waitImpl(0L, condition)
} }
private[this] def safeExecute[T](body: => T): T = { private[this] def safeExecute[T](body: => T): T = {
@ -102,13 +81,9 @@ class SparkParallelismTracker(
logger.info("starting training without setting timeout for waiting for resources") logger.info("starting training without setting timeout for waiting for resources")
body body
} else { } else {
try { logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
logger.info(s"starting training with timeout set as $timeout ms for waiting for resources") if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
waitForCondition(numAliveCores >= requestedCores, timeout) throw new IllegalStateException(s"Unable to get $requestedCores cores for XGBoost training")
} catch {
case _: TimeoutException =>
throw new IllegalStateException(s"Unable to get $requestedCores workers for" +
s" XGBoost training")
} }
safeExecute(body) safeExecute(body)
} }