[jvm-packages] refine numAliveCores method of SparkParallelismTracker (#4858)
* refine numAliveCores * refine XGBoostToMLlibParams * fix waitForCondition * resolve conflicts * Update SparkParallelismTracker.scala
This commit is contained in:
parent
22209b7b95
commit
277e25797b
@ -251,17 +251,16 @@ private[spark] trait ParamMapFuncs extends Params {
|
||||
" and grow_histmaker,prune or hist as the updater type")
|
||||
}
|
||||
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||
params.find(_.name == name) match {
|
||||
case None =>
|
||||
case Some(_: DoubleParam) =>
|
||||
params.find(_.name == name).foreach {
|
||||
case _: DoubleParam =>
|
||||
set(name, paramValue.toString.toDouble)
|
||||
case Some(_: BooleanParam) =>
|
||||
case _: BooleanParam =>
|
||||
set(name, paramValue.toString.toBoolean)
|
||||
case Some(_: IntParam) =>
|
||||
case _: IntParam =>
|
||||
set(name, paramValue.toString.toInt)
|
||||
case Some(_: FloatParam) =>
|
||||
case _: FloatParam =>
|
||||
set(name, paramValue.toString.toFloat)
|
||||
case Some(_: Param[_]) =>
|
||||
case _: Param[_] =>
|
||||
set(name, paramValue)
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,18 +16,8 @@
|
||||
|
||||
package org.apache.spark
|
||||
|
||||
import java.net.URL
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
import org.apache.spark.scheduler._
|
||||
import org.codehaus.jackson.map.ObjectMapper
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
import scala.concurrent.duration._
|
||||
import scala.concurrent.{Await, Future, TimeoutException}
|
||||
import scala.util.control.ControlThrowable
|
||||
|
||||
/**
|
||||
* A tracker that ensures enough number of executor cores are alive.
|
||||
@ -43,39 +33,28 @@ class SparkParallelismTracker(
|
||||
numWorkers: Int) {
|
||||
|
||||
private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
|
||||
private[this] val mapper = new ObjectMapper()
|
||||
private[this] val logger = LogFactory.getLog("XGBoostSpark")
|
||||
private[this] val url = sc.uiWebUrl match {
|
||||
case Some(baseUrl) => new URL(s"$baseUrl/api/v1/applications/${sc.applicationId}/executors")
|
||||
case _ => null
|
||||
}
|
||||
|
||||
private[this] def numAliveCores: Int = {
|
||||
try {
|
||||
if (url != null) {
|
||||
mapper.readTree(url).findValues("totalCores").asScala.map(_.asInt).sum
|
||||
} else {
|
||||
Int.MaxValue
|
||||
}
|
||||
} catch {
|
||||
case ex: Throwable =>
|
||||
logger.warn(s"Unable to read total number of alive cores from REST API." +
|
||||
s"Health Check will be ignored.")
|
||||
ex.printStackTrace()
|
||||
Int.MaxValue
|
||||
}
|
||||
sc.statusStore.executorList(true).map(_.totalCores).sum
|
||||
}
|
||||
|
||||
private[this] def waitForCondition(
|
||||
condition: => Boolean,
|
||||
timeout: Long,
|
||||
checkInterval: Long = 100L) = {
|
||||
val monitor = Future {
|
||||
while (!condition) {
|
||||
Thread.sleep(checkInterval)
|
||||
val waitImpl = new ((Long, Boolean) => Boolean) {
|
||||
override def apply(waitedTime: Long, status: Boolean): Boolean = status match {
|
||||
case s if s => true
|
||||
case _ => waitedTime match {
|
||||
case t if t < timeout =>
|
||||
Thread.sleep(checkInterval)
|
||||
apply(t + checkInterval, status = condition)
|
||||
case _ => false
|
||||
}
|
||||
}
|
||||
}
|
||||
Await.ready(monitor, timeout.millis)
|
||||
waitImpl(0L, condition)
|
||||
}
|
||||
|
||||
private[this] def safeExecute[T](body: => T): T = {
|
||||
@ -102,13 +81,9 @@ class SparkParallelismTracker(
|
||||
logger.info("starting training without setting timeout for waiting for resources")
|
||||
body
|
||||
} else {
|
||||
try {
|
||||
logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
|
||||
waitForCondition(numAliveCores >= requestedCores, timeout)
|
||||
} catch {
|
||||
case _: TimeoutException =>
|
||||
throw new IllegalStateException(s"Unable to get $requestedCores workers for" +
|
||||
s" XGBoost training")
|
||||
logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
|
||||
if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
|
||||
throw new IllegalStateException(s"Unable to get $requestedCores cores for XGBoost training")
|
||||
}
|
||||
safeExecute(body)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user