[jvm-packages] Scala implementation of the Rabit tracker. (#1612)

* [jvm-packages] Scala implementation of the Rabit tracker.

A Scala implementation of RabitTracker that is interface-interchangable with the
Java implementation, ported from `tracker.py` in the
[dmlc-core project](https://github.com/dmlc/dmlc-core).

* [jvm-packages] Updated Akka dependency in pom.xml.

* Refactored the RabitTracker directory structure.

* Fixed premature stopping of connection handler.

Added a new finite state "AwaitingPortNumber" to explicitly wait for the
worker to send the port, and close the connection. Stopping the actor
prematurely sends a TCP RST to the worker, causing the worker to crash
on AssertionError.

* Added interface IRabitTracker so that user can switch implementations.

* Default timeout duration changes.

* Dependency for Akka tests.

* Removed the main function of RabitTracker.

* A skeleton for testing Akka-based Rabit tracker.

* waitFor() in RabitTracker no longer throws exceptions.

* Completed unit test for the 'start' command of Rabit tracker.

* Preliminary support for Rabit Allreduce via JNI (no prepare function support yet.)

* Fixed the default timeout duration.

* Use Java container to avoid serialization issues due to intermediate wrappers.

* Added tests for Allreduce/model training using Scala Rabit tracker.

* Added spill-over unit test for the Scala Rabit tracker.

* Fixed a typo.

* Overhaul of RabitTracker interface per code review.

  - Removed methods start() waitFor() (no arguments) from IRabitTracker.
  - The timeout in start(timeout) is now worker connection timeout, as tcp
    socket binding timeout is less intuitive.
  - Dropped time unit from start(...) and waitFor(...) methods; the default
    time unit is millisecond.
  - Moved random port number generation into the RabitTrackerHandler.
  - Moved all Rabit-related classes to package ml.dmlc.xgboost4j.scala.rabit.

* More code refactoring and comments.

* Unified timeout constants. Readable tracker status code.

* Add comments to indicate that allReduce is for tests only. Removed all other variants.

* Removed unused imports.

* Simplified signatures of training methods.

 - Moved TrackerConf into parameter map.
 - Changed GeneralParams so that TrackerConf becomes a standalone parameter.
 - Updated test cases accordingly.

* Changed monitoring strategies.

* Reverted monitoring changes.

* Update test case for Rabit AllReduce.

* Mix in UncaughtExceptionHandler into IRabitTracker to prevent tracker from hanging due to exceptions thrown by workers.

* More comprehensive test cases for exception handling and worker connection timeout.

* Handle executor loss due to unknown cause: the newly spawned executor will attempt to connect to the tracker. Interrupt tracker in such case.

* Per code-review, removed training timeout from TrackerConf. Timeout logic must be implemented explicitly and externally in the driver code.

* Reverted scalastyle-config changes.

* Visibility scope change. Interface tweaks.

* Use match pattern to handle tracker_conf parameter.

* Minor clarification in JNI code.

* Clearer intent in match pattern to suppress warnings.

* Removed Future from constructor. Block in start() and waitFor() instead.

* Revert inadvertent comment changes.

* Removed debugging information.

* Updated test cases that are a bit finicky.

* Added comments on the reasoning behind the unit tests for testing Rabit tracker robustness.
This commit is contained in:
Xin Yin
2016-12-07 09:35:42 -05:00
committed by Nan Zhu
parent 7078c41dad
commit e7fbc8591f
19 changed files with 1910 additions and 25 deletions

View File

@@ -16,11 +16,10 @@
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FSDataInputStream, Path}
@@ -30,6 +29,25 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.{SparkContext, TaskContext}
import scala.concurrent.duration.{Duration, MILLISECONDS}
object TrackerConf {
def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python")
}
/**
* Rabit tracker configurations.
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
* Set timeout length to zero to disable timeout.
* Use a finite, non-zero timeout value to prevent tracker from
* hanging indefinitely (supported by "scala" implementation only.)
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*/
case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
@@ -80,7 +98,7 @@ object XGBoost extends Serializable {
private[spark] def buildDistributedBoosters(
trainingSet: RDD[MLLabeledPoint],
xgBoostConfMap: Map[String, Any],
rabitEnv: mutable.Map[String, String],
rabitEnv: java.util.Map[String, String],
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
import DataUtils._
@@ -92,7 +110,7 @@ object XGBoost extends Serializable {
partitionedTrainingSet.mapPartitions {
trainingSamples =>
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
Rabit.init(rabitEnv)
var booster: Booster = null
if (trainingSamples.hasNext) {
val cacheFileName: String = {
@@ -211,9 +229,21 @@ object XGBoost extends Serializable {
overridedParams
}
private def startTracker(nWorkers: Int): RabitTracker = {
val tracker = new RabitTracker(nWorkers)
require(tracker.start(), "FAULT: Failed to start tracker")
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker: IRabitTracker = trackerConf.trackerImpl match {
case "scala" => new RabitTracker(nWorkers)
case "python" => new PyRabitTracker(nWorkers)
case _ => new PyRabitTracker(nWorkers)
}
val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) {
trackerConf.workerConnectionTimeout.toMillis
} else {
// 0 == Duration.Inf
0L
}
require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker")
tracker
}
@@ -227,7 +257,7 @@ object XGBoost extends Serializable {
* @param obj the user-defined objective function, null by default
* @param eval the user-defined evaluation function, null by default
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
* true, the user may save the RAM cost for running XGBoost within Spark
* true, the user may save the RAM cost for running XGBoost within Spark
* @param missing the value represented the missing value in the dataset
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
* @return XGBoostModel when successful training
@@ -243,19 +273,26 @@ object XGBoost extends Serializable {
" you have to specify the objective type as classification or regression with a" +
" customized objective function")
}
val tracker = startTracker(nWorkers)
val trackerConf = params.get("tracker_conf") match {
case None => TrackerConf()
case Some(conf: TrackerConf) => conf
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
"instance of TrackerConf.")
}
val tracker = startTracker(nWorkers, trackerConf)
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing)
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
val sparkJobThread = new Thread() {
override def run() {
// force the job
boosters.foreachPartition(() => _)
}
}
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
val trackerReturnVal = tracker.waitFor()
val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal")
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
isClsTask)

View File

@@ -16,9 +16,12 @@
package ml.dmlc.xgboost4j.scala.spark.params
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.spark.ml.param._
import scala.concurrent.duration.{Duration, NANOSECONDS}
trait GeneralParams extends Params {
/**
@@ -69,7 +72,38 @@ trait GeneralParams extends Params {
*/
val missing = new FloatParam(this, "missing", "the value treated as missing")
/**
* Rabit tracker configurations. The parameter must be provided as an instance of the
* TrackerConf class, which has the following definition:
*
* case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration,
* trackerImpl: String)
*
* See below for detailed explanations.
*
* - trackerImpl: Select the implementation of Rabit tracker.
* default: "python"
*
* Choice between "python" or "scala". The former utilizes the Java wrapper of the
* Python Rabit tracker (in dmlc_core), and does not support timeout settings.
* The "scala" version removes Python components, and fully supports timeout settings.
*
* - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
* default: 0 millisecond (no timeout)
*
* The timeout value should take the time of data loading and pre-processing into account,
* due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
* perform data transformation before calling XGBoost.train(), so that this timeout truly
* reflects the connection delay. Set a reasonable timeout value to prevent model
* training/testing from hanging indefinitely, possible due to network issues.
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
* Ignored if the tracker implementation is "python".
*/
val trackerConf = new Param[TrackerConf](this, "tracker_conf", "Rabit tracker configurations")
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
useExternalMemory -> false, silent -> 0,
customObj -> null, customEval -> null, missing -> Float.NaN)
customObj -> null, customEval -> null, missing -> Float.NaN,
trackerConf -> TrackerConf()
)
}