diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index c42889964..8d6bc74ab 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -251,17 +251,16 @@ private[spark] trait ParamMapFuncs extends Params { " and grow_histmaker,prune or hist as the updater type") } val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName) - params.find(_.name == name) match { - case None => - case Some(_: DoubleParam) => + params.find(_.name == name).foreach { + case _: DoubleParam => set(name, paramValue.toString.toDouble) - case Some(_: BooleanParam) => + case _: BooleanParam => set(name, paramValue.toString.toBoolean) - case Some(_: IntParam) => + case _: IntParam => set(name, paramValue.toString.toInt) - case Some(_: FloatParam) => + case _: FloatParam => set(name, paramValue.toString.toFloat) - case Some(_: Param[_]) => + case _: Param[_] => set(name, paramValue) } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala index 84b092cad..bd0ab4d6d 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala @@ -16,18 +16,8 @@ package org.apache.spark -import java.net.URL -import java.util.concurrent.atomic.AtomicBoolean - import org.apache.commons.logging.LogFactory - import org.apache.spark.scheduler._ -import org.codehaus.jackson.map.ObjectMapper -import scala.collection.JavaConverters._ -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.duration._ -import scala.concurrent.{Await, Future, TimeoutException} -import scala.util.control.ControlThrowable /** * A tracker that ensures enough number of executor cores are alive. @@ -43,39 +33,28 @@ class SparkParallelismTracker( numWorkers: Int) { private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1) - private[this] val mapper = new ObjectMapper() private[this] val logger = LogFactory.getLog("XGBoostSpark") - private[this] val url = sc.uiWebUrl match { - case Some(baseUrl) => new URL(s"$baseUrl/api/v1/applications/${sc.applicationId}/executors") - case _ => null - } private[this] def numAliveCores: Int = { - try { - if (url != null) { - mapper.readTree(url).findValues("totalCores").asScala.map(_.asInt).sum - } else { - Int.MaxValue - } - } catch { - case ex: Throwable => - logger.warn(s"Unable to read total number of alive cores from REST API." + - s"Health Check will be ignored.") - ex.printStackTrace() - Int.MaxValue - } + sc.statusStore.executorList(true).map(_.totalCores).sum } private[this] def waitForCondition( condition: => Boolean, timeout: Long, checkInterval: Long = 100L) = { - val monitor = Future { - while (!condition) { - Thread.sleep(checkInterval) + val waitImpl = new ((Long, Boolean) => Boolean) { + override def apply(waitedTime: Long, status: Boolean): Boolean = status match { + case s if s => true + case _ => waitedTime match { + case t if t < timeout => + Thread.sleep(checkInterval) + apply(t + checkInterval, status = condition) + case _ => false + } } } - Await.ready(monitor, timeout.millis) + waitImpl(0L, condition) } private[this] def safeExecute[T](body: => T): T = { @@ -102,13 +81,9 @@ class SparkParallelismTracker( logger.info("starting training without setting timeout for waiting for resources") body } else { - try { - logger.info(s"starting training with timeout set as $timeout ms for waiting for resources") - waitForCondition(numAliveCores >= requestedCores, timeout) - } catch { - case _: TimeoutException => - throw new IllegalStateException(s"Unable to get $requestedCores workers for" + - s" XGBoost training") + logger.info(s"starting training with timeout set as $timeout ms for waiting for resources") + if (!waitForCondition(numAliveCores >= requestedCores, timeout)) { + throw new IllegalStateException(s"Unable to get $requestedCores cores for XGBoost training") } safeExecute(body) }