[jvm-packages] refine numAliveCores method of SparkParallelismTracker (#4858)
* refine numAliveCores
* refine XGBoostToMLlibParams
* fix waitForCondition
* resolve conflicts
* Update SparkParallelismTracker.scala
parent 22209b7b95
commit 277e25797b
@@ -251,17 +251,16 @@ private[spark] trait ParamMapFuncs extends Params {
           " and grow_histmaker,prune or hist as the updater type")
       }
       val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
-      params.find(_.name == name) match {
-        case None =>
-        case Some(_: DoubleParam) =>
+      params.find(_.name == name).foreach {
+        case _: DoubleParam =>
           set(name, paramValue.toString.toDouble)
-        case Some(_: BooleanParam) =>
+        case _: BooleanParam =>
           set(name, paramValue.toString.toBoolean)
-        case Some(_: IntParam) =>
+        case _: IntParam =>
           set(name, paramValue.toString.toInt)
-        case Some(_: FloatParam) =>
+        case _: FloatParam =>
           set(name, paramValue.toString.toFloat)
-        case Some(_: Param[_]) =>
+        case _: Param[_] =>
           set(name, paramValue)
       }
     }
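
The params hunk replaces a `match` on `Option[Param[_]]` whose `case None =>` arm did nothing with `Option.foreach`, so each case can match the `Param` subtype directly instead of unwrapping `Some(_)`. A minimal, self-contained sketch of the pattern; `ParamLike`, `DoubleP`, `BooleanP`, and `set` are hypothetical stand-ins for Spark ML's `Param` classes:

    object ForeachOverMatch {
      // Hypothetical stand-ins for Spark ML's Param hierarchy.
      sealed trait ParamLike { def name: String }
      case class DoubleP(name: String) extends ParamLike
      case class BooleanP(name: String) extends ParamLike

      val known: Seq[ParamLike] = Seq(DoubleP("eta"), BooleanP("silent"))

      def set(name: String, value: Any): Unit = println(s"set $name = $value")

      def copyParam(name: String, raw: String): Unit =
        // foreach runs the partial function only when find returns Some,
        // so the empty `case None =>` arm disappears.
        known.find(_.name == name).foreach {
          case _: DoubleP  => set(name, raw.toDouble)
          case _: BooleanP => set(name, raw.toBoolean)
        }
    }
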
@@ -16,18 +16,8 @@
 
 package org.apache.spark
 
-import java.net.URL
-import java.util.concurrent.atomic.AtomicBoolean
-
 import org.apache.commons.logging.LogFactory
-
 import org.apache.spark.scheduler._
-import org.codehaus.jackson.map.ObjectMapper
-import scala.collection.JavaConverters._
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent.duration._
-import scala.concurrent.{Await, Future, TimeoutException}
-import scala.util.control.ControlThrowable
 
 /**
  * A tracker that ensures enough number of executor cores are alive.

@@ -43,26 +33,9 @@ class SparkParallelismTracker(
     numWorkers: Int) {
 
   private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
-  private[this] val mapper = new ObjectMapper()
   private[this] val logger = LogFactory.getLog("XGBoostSpark")
-  private[this] val url = sc.uiWebUrl match {
-    case Some(baseUrl) => new URL(s"$baseUrl/api/v1/applications/${sc.applicationId}/executors")
-    case _ => null
-  }
 
   private[this] def numAliveCores: Int = {
-    try {
-      if (url != null) {
-        mapper.readTree(url).findValues("totalCores").asScala.map(_.asInt).sum
-      } else {
-        Int.MaxValue
-      }
-    } catch {
-      case ex: Throwable =>
-        logger.warn(s"Unable to read total number of alive cores from REST API." +
-          s"Health Check will be ignored.")
-        ex.printStackTrace()
-        Int.MaxValue
-    }
+    sc.statusStore.executorList(true).map(_.totalCores).sum
   }
 
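
The rewritten `numAliveCores` reads executor state from Spark's in-process `AppStatusStore` instead of scraping the UI's REST endpoint, so it drops the Jackson parsing, cannot fail on a network error, and no longer degrades to `Int.MaxValue` (which silently disabled the health check whenever `sc.uiWebUrl` was unavailable). A hedged sketch of the call, assuming Spark 2.3+ where `SparkContext.statusStore` exists; that field is `private[spark]`, which is why the tracker itself is declared in package `org.apache.spark`:

    package org.apache.spark

    object AliveCoresSketch {
      // Sum cores over currently active executors, straight from the
      // driver's AppStatusStore: no HTTP round trip, no JSON parsing.
      def aliveCores(sc: SparkContext): Int =
        sc.statusStore
          .executorList(true)   // activeOnly = true: skip dead executors
          .map(_.totalCores)
          .sum
    }
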
@@ -69,13 +42,19 @@ class SparkParallelismTracker(
   private[this] def waitForCondition(
       condition: => Boolean,
       timeout: Long,
       checkInterval: Long = 100L) = {
-    val monitor = Future {
-      while (!condition) {
-        Thread.sleep(checkInterval)
+    val waitImpl = new ((Long, Boolean) => Boolean) {
+      override def apply(waitedTime: Long, status: Boolean): Boolean = status match {
+        case s if s => true
+        case _ => waitedTime match {
+          case t if t < timeout =>
+            Thread.sleep(checkInterval)
+            apply(t + checkInterval, status = condition)
+          case _ => false
+        }
       }
     }
-    Await.ready(monitor, timeout.millis)
+    waitImpl(0L, condition)
   }
 
   private[this] def safeExecute[T](body: => T): T = {
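
The old `waitForCondition` spun inside a `Future` and signalled timeout by letting `Await.ready` throw `TimeoutException`; on timeout the orphaned future kept polling on the global `ExecutionContext`, since Scala futures cannot be cancelled. The new version polls synchronously and reports the outcome as a `Boolean`. An imperative sketch that should behave the same as the recursive form above:

    object WaitSketch {
      // Re-evaluate the by-name condition every checkInterval ms until it
      // holds or the time budget runs out; the Boolean result replaces the
      // old TimeoutException-based signalling.
      def waitForCondition(condition: => Boolean,
                           timeout: Long,
                           checkInterval: Long = 100L): Boolean = {
        var waited = 0L
        var met = condition
        while (!met && waited < timeout) {
          Thread.sleep(checkInterval)
          waited += checkInterval
          met = condition   // by-name parameter, re-checked each pass
        }
        met
      }
    }
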
@@ -102,13 +81,9 @@ class SparkParallelismTracker(
       logger.info("starting training without setting timeout for waiting for resources")
       body
     } else {
-      try {
-        logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
-        waitForCondition(numAliveCores >= requestedCores, timeout)
-      } catch {
-        case _: TimeoutException =>
-          throw new IllegalStateException(s"Unable to get $requestedCores workers for" +
-            s" XGBoost training")
+      logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
+      if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
+        throw new IllegalStateException(s"Unable to get $requestedCores cores for XGBoost training")
       }
       safeExecute(body)
     }
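
With the `Boolean` result, the caller now raises `IllegalStateException` itself instead of translating a caught `TimeoutException`, and the message reports cores rather than workers, matching what is actually counted. A hypothetical caller, assuming the enclosing method shown in this hunk is the tracker's public `execute(body)` entry point (its signature is outside the hunk) and `trainXGBoostModel` is a placeholder:

    val tracker = new SparkParallelismTracker(sc, timeout = 60000L, numWorkers = 4)
    // Fails fast with IllegalStateException unless numWorkers * spark.task.cpus
    // cores come alive within 60 s; otherwise the body runs under safeExecute.
    val model = tracker.execute {
      trainXGBoostModel()   // placeholder for the real training call
    }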