[jvm-packages] Scala implementation of the Rabit tracker. (#1612)
* [jvm-packages] Scala implementation of the Rabit tracker. A Scala implementation of RabitTracker that is interface-interchangable with the Java implementation, ported from `tracker.py` in the [dmlc-core project](https://github.com/dmlc/dmlc-core). * [jvm-packages] Updated Akka dependency in pom.xml. * Refactored the RabitTracker directory structure. * Fixed premature stopping of connection handler. Added a new finite state "AwaitingPortNumber" to explicitly wait for the worker to send the port, and close the connection. Stopping the actor prematurely sends a TCP RST to the worker, causing the worker to crash on AssertionError. * Added interface IRabitTracker so that user can switch implementations. * Default timeout duration changes. * Dependency for Akka tests. * Removed the main function of RabitTracker. * A skeleton for testing Akka-based Rabit tracker. * waitFor() in RabitTracker no longer throws exceptions. * Completed unit test for the 'start' command of Rabit tracker. * Preliminary support for Rabit Allreduce via JNI (no prepare function support yet.) * Fixed the default timeout duration. * Use Java container to avoid serialization issues due to intermediate wrappers. * Added tests for Allreduce/model training using Scala Rabit tracker. * Added spill-over unit test for the Scala Rabit tracker. * Fixed a typo. * Overhaul of RabitTracker interface per code review. - Removed methods start() waitFor() (no arguments) from IRabitTracker. - The timeout in start(timeout) is now worker connection timeout, as tcp socket binding timeout is less intuitive. - Dropped time unit from start(...) and waitFor(...) methods; the default time unit is millisecond. - Moved random port number generation into the RabitTrackerHandler. - Moved all Rabit-related classes to package ml.dmlc.xgboost4j.scala.rabit. * More code refactoring and comments. * Unified timeout constants. Readable tracker status code. * Add comments to indicate that allReduce is for tests only. Removed all other variants. * Removed unused imports. * Simplified signatures of training methods. - Moved TrackerConf into parameter map. - Changed GeneralParams so that TrackerConf becomes a standalone parameter. - Updated test cases accordingly. * Changed monitoring strategies. * Reverted monitoring changes. * Update test case for Rabit AllReduce. * Mix in UncaughtExceptionHandler into IRabitTracker to prevent tracker from hanging due to exceptions thrown by workers. * More comprehensive test cases for exception handling and worker connection timeout. * Handle executor loss due to unknown cause: the newly spawned executor will attempt to connect to the tracker. Interrupt tracker in such case. * Per code-review, removed training timeout from TrackerConf. Timeout logic must be implemented explicitly and externally in the driver code. * Reverted scalastyle-config changes. * Visibility scope change. Interface tweaks. * Use match pattern to handle tracker_conf parameter. * Minor clarification in JNI code. * Clearer intent in match pattern to suppress warnings. * Removed Future from constructor. Block in start() and waitFor() instead. * Revert inadvertent comment changes. * Removed debugging information. * Updated test cases that are a bit finicky. * Added comments on the reasoning behind the unit tests for testing Rabit tracker robustness.
This commit is contained in:
parent
7078c41dad
commit
e7fbc8591f
@ -57,7 +57,7 @@ object CustomObjective {
|
||||
case e: XGBoostError =>
|
||||
logger.error(e)
|
||||
null
|
||||
case _ =>
|
||||
case _: Throwable =>
|
||||
null
|
||||
}
|
||||
val grad = new Array[Float](nrow)
|
||||
|
||||
@ -85,7 +85,7 @@ object XGBoost {
|
||||
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
|
||||
XGBoostModel = {
|
||||
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
|
||||
if (tracker.start()) {
|
||||
if (tracker.start(0L)) {
|
||||
dtrain
|
||||
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
|
||||
.reduce((x, y) => x).collect().head
|
||||
|
||||
@ -16,11 +16,10 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.hadoop.fs.{FSDataInputStream, Path}
|
||||
@ -30,6 +29,25 @@ import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.{SparkContext, TaskContext}
|
||||
|
||||
import scala.concurrent.duration.{Duration, MILLISECONDS}
|
||||
|
||||
object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python")
|
||||
}
|
||||
|
||||
/**
|
||||
* Rabit tracker configurations.
|
||||
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
|
||||
* Set timeout length to zero to disable timeout.
|
||||
* Use a finite, non-zero timeout value to prevent tracker from
|
||||
* hanging indefinitely (supported by "scala" implementation only.)
|
||||
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
|
||||
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
|
||||
* in Scala without Python components, and with full support of timeouts.
|
||||
* The Scala implementation is currently experimental, use at your own risk.
|
||||
*/
|
||||
case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
|
||||
|
||||
object XGBoost extends Serializable {
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
|
||||
@ -80,7 +98,7 @@ object XGBoost extends Serializable {
|
||||
private[spark] def buildDistributedBoosters(
|
||||
trainingSet: RDD[MLLabeledPoint],
|
||||
xgBoostConfMap: Map[String, Any],
|
||||
rabitEnv: mutable.Map[String, String],
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
|
||||
useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
|
||||
import DataUtils._
|
||||
@ -92,7 +110,7 @@ object XGBoost extends Serializable {
|
||||
partitionedTrainingSet.mapPartitions {
|
||||
trainingSamples =>
|
||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
Rabit.init(rabitEnv)
|
||||
var booster: Booster = null
|
||||
if (trainingSamples.hasNext) {
|
||||
val cacheFileName: String = {
|
||||
@ -211,9 +229,21 @@ object XGBoost extends Serializable {
|
||||
overridedParams
|
||||
}
|
||||
|
||||
private def startTracker(nWorkers: Int): RabitTracker = {
|
||||
val tracker = new RabitTracker(nWorkers)
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||
val tracker: IRabitTracker = trackerConf.trackerImpl match {
|
||||
case "scala" => new RabitTracker(nWorkers)
|
||||
case "python" => new PyRabitTracker(nWorkers)
|
||||
case _ => new PyRabitTracker(nWorkers)
|
||||
}
|
||||
|
||||
val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) {
|
||||
trackerConf.workerConnectionTimeout.toMillis
|
||||
} else {
|
||||
// 0 == Duration.Inf
|
||||
0L
|
||||
}
|
||||
|
||||
require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker")
|
||||
tracker
|
||||
}
|
||||
|
||||
@ -227,7 +257,7 @@ object XGBoost extends Serializable {
|
||||
* @param obj the user-defined objective function, null by default
|
||||
* @param eval the user-defined evaluation function, null by default
|
||||
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
|
||||
* true, the user may save the RAM cost for running XGBoost within Spark
|
||||
* true, the user may save the RAM cost for running XGBoost within Spark
|
||||
* @param missing the value represented the missing value in the dataset
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
|
||||
* @return XGBoostModel when successful training
|
||||
@ -243,19 +273,26 @@ object XGBoost extends Serializable {
|
||||
" you have to specify the objective type as classification or regression with a" +
|
||||
" customized objective function")
|
||||
}
|
||||
val tracker = startTracker(nWorkers)
|
||||
val trackerConf = params.get("tracker_conf") match {
|
||||
case None => TrackerConf()
|
||||
case Some(conf: TrackerConf) => conf
|
||||
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
|
||||
"instance of TrackerConf.")
|
||||
}
|
||||
val tracker = startTracker(nWorkers, trackerConf)
|
||||
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
|
||||
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
|
||||
tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing)
|
||||
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
// force the job
|
||||
boosters.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkJobThread.start()
|
||||
val isClsTask = isClassificationTask(params)
|
||||
val trackerReturnVal = tracker.waitFor()
|
||||
val trackerReturnVal = tracker.waitFor(0L)
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
|
||||
isClsTask)
|
||||
|
||||
@ -16,9 +16,12 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||
import org.apache.spark.ml.param._
|
||||
|
||||
import scala.concurrent.duration.{Duration, NANOSECONDS}
|
||||
|
||||
trait GeneralParams extends Params {
|
||||
|
||||
/**
|
||||
@ -69,7 +72,38 @@ trait GeneralParams extends Params {
|
||||
*/
|
||||
val missing = new FloatParam(this, "missing", "the value treated as missing")
|
||||
|
||||
/**
|
||||
* Rabit tracker configurations. The parameter must be provided as an instance of the
|
||||
* TrackerConf class, which has the following definition:
|
||||
*
|
||||
* case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration,
|
||||
* trackerImpl: String)
|
||||
*
|
||||
* See below for detailed explanations.
|
||||
*
|
||||
* - trackerImpl: Select the implementation of Rabit tracker.
|
||||
* default: "python"
|
||||
*
|
||||
* Choice between "python" or "scala". The former utilizes the Java wrapper of the
|
||||
* Python Rabit tracker (in dmlc_core), and does not support timeout settings.
|
||||
* The "scala" version removes Python components, and fully supports timeout settings.
|
||||
*
|
||||
* - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
|
||||
* default: 0 millisecond (no timeout)
|
||||
*
|
||||
* The timeout value should take the time of data loading and pre-processing into account,
|
||||
* due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
|
||||
* perform data transformation before calling XGBoost.train(), so that this timeout truly
|
||||
* reflects the connection delay. Set a reasonable timeout value to prevent model
|
||||
* training/testing from hanging indefinitely, possible due to network issues.
|
||||
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
|
||||
* Ignored if the tracker implementation is "python".
|
||||
*/
|
||||
val trackerConf = new Param[TrackerConf](this, "tracker_conf", "Rabit tracker configurations")
|
||||
|
||||
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
|
||||
useExternalMemory -> false, silent -> 0,
|
||||
customObj -> null, customEval -> null, missing -> Float.NaN)
|
||||
customObj -> null, customEval -> null, missing -> Float.NaN,
|
||||
trackerConf -> TrackerConf()
|
||||
)
|
||||
}
|
||||
|
||||
@ -0,0 +1,169 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
|
||||
class RabitTrackerRobustnessSuite extends FunSuite with Utils {
|
||||
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
|
||||
/*
|
||||
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
|
||||
same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
|
||||
by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
|
||||
corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
|
||||
tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
|
||||
that should be avoided.
|
||||
*/
|
||||
val sparkConf = new SparkConf().setMaster("local[*]")
|
||||
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
|
||||
implicit val sparkContext = new SparkContext(sparkConf)
|
||||
sparkContext.setLogLevel("ERROR")
|
||||
|
||||
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new PyRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
/*
|
||||
Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
|
||||
last created worker. A cascading event chain will be triggered once the RuntimeException is
|
||||
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
|
||||
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
|
||||
|
||||
The Java RabitTracker class reacts to exceptions by killing the spawned process running
|
||||
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
|
||||
it is killed, the resulted connection failure will trigger the Rabit worker to call
|
||||
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
|
||||
|
||||
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
|
||||
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
|
||||
running in separate containers. However, as unit tests are run in Spark local mode, in which
|
||||
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
|
||||
ultimately kills the entire process, which also happens to host the Spark driver, causing
|
||||
the entire Spark application to crash.
|
||||
|
||||
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
|
||||
the exception is thrown at last, ideally after all worker connections have been established.
|
||||
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
|
||||
process to ensure that pending worker connections are handled.
|
||||
*/
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val index = iter.next()
|
||||
Thread.sleep(100 + index * 10)
|
||||
if (index == workerCount) {
|
||||
// kill the worker by throwing an exception
|
||||
throw new RuntimeException("Worker exception.")
|
||||
}
|
||||
Rabit.shutdown()
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0) != 0)
|
||||
sparkContext.stop()
|
||||
}
|
||||
|
||||
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
|
||||
val sparkConf = new SparkConf().setMaster("local[*]")
|
||||
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
|
||||
implicit val sparkContext = new SparkContext(sparkConf)
|
||||
sparkContext.setLogLevel("ERROR")
|
||||
|
||||
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val index = iter.next()
|
||||
Thread.sleep(100 + index * 10)
|
||||
if (index == workerCount) {
|
||||
// kill the worker by throwing an exception
|
||||
throw new RuntimeException("Worker exception.")
|
||||
}
|
||||
Rabit.shutdown()
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||
sparkContext.stop()
|
||||
}
|
||||
|
||||
test("test Scala RabitTracker's workerConnectionTimeout") {
|
||||
val sparkConf = new SparkConf().setMaster("local[*]")
|
||||
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
|
||||
implicit val sparkContext = new SparkContext(sparkConf)
|
||||
sparkContext.setLogLevel("ERROR")
|
||||
|
||||
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(500)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
val index = iter.next()
|
||||
// simulate that the first worker cannot connect to tracker due to network issues.
|
||||
if (index != 1) {
|
||||
Rabit.init(trackerEnvs)
|
||||
Thread.sleep(1000)
|
||||
Rabit.shutdown()
|
||||
}
|
||||
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
// should fail due to connection timeout
|
||||
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||
sparkContext.stop()
|
||||
}
|
||||
}
|
||||
@ -17,18 +17,60 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.nio.file.Files
|
||||
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import scala.concurrent.duration._
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.{Vector => SparkVector, Vectors}
|
||||
import org.apache.spark.ml.linalg.{Vectors, Vector => SparkVector}
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
|
||||
val vectorLength = 100
|
||||
val rdd = sc.parallelize(
|
||||
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
|
||||
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
|
||||
|
||||
val rawData = rdd.mapPartitions { iter =>
|
||||
Iterator(iter.toArray)
|
||||
}.collect()
|
||||
|
||||
val maxVec = (0 until vectorLength).toArray.map { j =>
|
||||
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
|
||||
}
|
||||
|
||||
val allReduceResults = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val arr = iter.toArray
|
||||
val results = Rabit.allReduce(arr, Rabit.OpType.MAX)
|
||||
Rabit.shutdown()
|
||||
Iterator(results)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
allReduceResults.foreachPartition(() => _)
|
||||
val byPartitionResults = allReduceResults.collect()
|
||||
assert(byPartitionResults(0).length == vectorLength)
|
||||
collectedAllReduceResults.put(byPartitionResults(0))
|
||||
}
|
||||
}
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0L) == 0)
|
||||
sparkThread.join()
|
||||
|
||||
assert(collectedAllReduceResults.poll().sameElements(maxVec))
|
||||
}
|
||||
|
||||
test("build RDD containing boosters with the specified worker number") {
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
@ -36,7 +78,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||
trainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap,
|
||||
new scala.collection.mutable.HashMap[String, String],
|
||||
new java.util.HashMap[String, String](),
|
||||
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true)
|
||||
val boosterCount = boosterRDD.count()
|
||||
assert(boosterCount === 2)
|
||||
@ -59,6 +101,21 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||
cleanExternalCache("XGBoostSuite")
|
||||
}
|
||||
|
||||
test("training with Scala-implemented Rabit tracker") {
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = numWorkers, useExternalMemory = true)
|
||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix) < 0.1)
|
||||
}
|
||||
|
||||
test("test with dense vectors containing missing value") {
|
||||
def buildDenseRDD(): RDD[LabeledPoint] = {
|
||||
val nrow = 100
|
||||
|
||||
@ -108,5 +108,17 @@
|
||||
<version>4.11</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.typesafe.akka</groupId>
|
||||
<artifactId>akka-actor_${scala.binary.version}</artifactId>
|
||||
<version>2.3.11</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.typesafe.akka</groupId>
|
||||
<artifactId>akka-testkit_${scala.binary.version}</artifactId>
|
||||
<version>2.3.11</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@ -0,0 +1,43 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Interface for Rabit tracker implementations with three public methods:
|
||||
*
|
||||
* - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given
|
||||
* timeout value (in milliseconds.)
|
||||
* - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
|
||||
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
|
||||
* milliseconds.
|
||||
*
|
||||
* Each implementation is expected to implement a callback function
|
||||
*
|
||||
* public void uncaughtException(Threat t, Throwable e) { ... }
|
||||
*
|
||||
* to interrupt waitFor() in order to prevent the tracker from hanging indefinitely.
|
||||
*
|
||||
* The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
|
||||
* brokers connections between workers.
|
||||
*/
|
||||
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
||||
public enum TrackerStatus {
|
||||
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
|
||||
|
||||
private int statusCode;
|
||||
|
||||
TrackerStatus(int statusCode) {
|
||||
this.statusCode = statusCode;
|
||||
}
|
||||
|
||||
public int getStatusCode() {
|
||||
return this.statusCode;
|
||||
}
|
||||
}
|
||||
|
||||
Map<String, String> getWorkerEnvs();
|
||||
boolean start(long workerConnectionTimeout);
|
||||
// taskExecutionTimeout has no effect in current version of XGBoost.
|
||||
int waitFor(long taskExecutionTimeout);
|
||||
}
|
||||
@ -2,6 +2,9 @@ package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.Serializable;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
@ -22,6 +25,42 @@ public class Rabit {
|
||||
}
|
||||
}
|
||||
|
||||
public enum OpType implements Serializable {
|
||||
MAX(0), MIN(1), SUM(2), BITWISE_OR(3);
|
||||
|
||||
private int op;
|
||||
|
||||
public int getOperand() {
|
||||
return this.op;
|
||||
}
|
||||
|
||||
OpType(int op) {
|
||||
this.op = op;
|
||||
}
|
||||
}
|
||||
|
||||
public enum DataType implements Serializable {
|
||||
CHAR(0, 1), UCHAR(1, 1), INT(2, 4), UNIT(3, 4),
|
||||
LONG(4, 8), ULONG(5, 8), FLOAT(6, 4), DOUBLE(7, 8),
|
||||
LONGLONG(8, 8), ULONGLONG(9, 8);
|
||||
|
||||
private int enumOp;
|
||||
private int size;
|
||||
|
||||
public int getEnumOp() {
|
||||
return this.enumOp;
|
||||
}
|
||||
|
||||
public int getSize() {
|
||||
return this.size;
|
||||
}
|
||||
|
||||
DataType(int enumOp, int size) {
|
||||
this.enumOp = enumOp;
|
||||
this.size = size;
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkCall(int ret) throws XGBoostError {
|
||||
if (ret != 0) {
|
||||
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
|
||||
@ -92,4 +131,30 @@ public class Rabit {
|
||||
checkCall(XGBoostJNI.RabitGetWorldSize(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* perform Allreduce on distributed float vectors using operator op.
|
||||
* This implementation of allReduce does not support customized prepare function callback in the
|
||||
* native code, as this function is meant for testing purposes only (to test the Rabit tracker.)
|
||||
*
|
||||
* @param elements local elements on distributed workers.
|
||||
* @param op operator used for Allreduce.
|
||||
* @return All-reduced float elements according to the given operator.
|
||||
*/
|
||||
public static float[] allReduce(float[] elements, OpType op) {
|
||||
DataType dataType = DataType.FLOAT;
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
|
||||
.order(ByteOrder.nativeOrder());
|
||||
|
||||
for (float el : elements) {
|
||||
buffer.putFloat(el);
|
||||
}
|
||||
buffer.flip();
|
||||
|
||||
XGBoostJNI.RabitAllreduce(buffer, elements.length, dataType.getEnumOp(), op.getOperand());
|
||||
float[] results = new float[elements.length];
|
||||
buffer.asFloatBuffer().get(results);
|
||||
|
||||
return results;
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,15 +5,24 @@ package ml.dmlc.xgboost4j.java;
|
||||
import java.io.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* Distributed RabitTracker, need to be started on driver code before running distributed jobs.
|
||||
* Java implementation of the Rabit tracker to coordinate distributed workers.
|
||||
* As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
|
||||
* start() and waitFor() methods (i.e., the timeout is infinite.)
|
||||
*
|
||||
* For systems lacking Python environment, or for timeout functionality, consider using the Scala
|
||||
* Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
|
||||
* provides timeout support.
|
||||
*
|
||||
* The tracker must be started on driver node before running distributed jobs.
|
||||
*/
|
||||
public class RabitTracker {
|
||||
public class RabitTracker implements IRabitTracker {
|
||||
// Maybe per tracker logger?
|
||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||
// tracker python file.
|
||||
@ -69,7 +78,6 @@ public class RabitTracker {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public RabitTracker(int numWorkers)
|
||||
throws XGBoostError {
|
||||
if (numWorkers < 1) {
|
||||
@ -78,6 +86,17 @@ public class RabitTracker {
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
public void uncaughtException(Thread t, Throwable e) {
|
||||
logger.error("Uncaught exception thrown by worker:", e);
|
||||
try {
|
||||
Thread.sleep(5000L);
|
||||
} catch (InterruptedException ex) {
|
||||
logger.error(ex);
|
||||
} finally {
|
||||
trackerProcess.get().destroy();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get environments that can be used to pass to worker.
|
||||
* @return The environment settings.
|
||||
@ -126,7 +145,13 @@ public class RabitTracker {
|
||||
}
|
||||
}
|
||||
|
||||
public boolean start() {
|
||||
public boolean start(long timeout) {
|
||||
if (timeout > 0L) {
|
||||
logger.warn("Python RabitTracker does not support timeout. " +
|
||||
"The tracker will wait for all workers to connect indefinitely, unless " +
|
||||
"it is interrupted manually. Use the Scala RabitTracker for timeout support.");
|
||||
}
|
||||
|
||||
if (startTrackerProcess()) {
|
||||
logger.debug("Tracker started, with env=" + envs.toString());
|
||||
System.out.println("Tracker started, with env=" + envs.toString());
|
||||
@ -142,7 +167,14 @@ public class RabitTracker {
|
||||
}
|
||||
}
|
||||
|
||||
public int waitFor() {
|
||||
public int waitFor(long timeout) {
|
||||
if (timeout > 0L) {
|
||||
logger.warn("Python RabitTracker does not support timeout. " +
|
||||
"The tracker will wait for either all workers to finish tasks and send " +
|
||||
"shutdown signal, or manual interruptions. " +
|
||||
"Use the Scala RabitTracker for timeout support.");
|
||||
}
|
||||
|
||||
try {
|
||||
trackerProcess.get().waitFor();
|
||||
int returnVal = trackerProcess.get().exitValue();
|
||||
@ -153,7 +185,7 @@ public class RabitTracker {
|
||||
// we should not get here as RabitTracker is accessed in the main thread
|
||||
e.printStackTrace();
|
||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||
return 1;
|
||||
return TrackerStatus.INTERRUPTED.getStatusCode();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,6 +16,8 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
/**
|
||||
* xgboost JNI functions
|
||||
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
|
||||
@ -97,4 +99,9 @@ class XGBoostJNI {
|
||||
public final static native int RabitGetRank(int[] out);
|
||||
public final static native int RabitGetWorldSize(int[] out);
|
||||
public final static native int RabitVersionNumber(int[] out);
|
||||
|
||||
// Perform Allreduce operation on data in sendrecvbuf.
|
||||
// This JNI function does not support the callback function for data preparation yet.
|
||||
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
|
||||
int enum_dtype, int enum_op);
|
||||
}
|
||||
|
||||
@ -0,0 +1,190 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.rabit
|
||||
|
||||
import java.net.{InetAddress, InetSocketAddress}
|
||||
|
||||
import akka.actor.ActorSystem
|
||||
import akka.pattern.ask
|
||||
import ml.dmlc.xgboost4j.java.IRabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler
|
||||
|
||||
import scala.concurrent.duration._
|
||||
import scala.concurrent.{Await, Future}
|
||||
import scala.util.{Failure, Success, Try}
|
||||
|
||||
/**
|
||||
* Scala implementation of the Rabit tracker interface without Python dependency.
|
||||
* The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker
|
||||
* (and thus any distributed tasks) to hang indefinitely due to network issues or worker node
|
||||
* failures.
|
||||
*
|
||||
* Note that this implementation is currently experimental, and should be used at your own risk.
|
||||
*
|
||||
* Example usage:
|
||||
* {{{
|
||||
* import scala.concurrent.duration._
|
||||
*
|
||||
* val tracker = new RabitTracker(32)
|
||||
* // allow up to 10 minutes for all workers to connect to the tracker.
|
||||
* tracker.start(10 minutes)
|
||||
*
|
||||
* /* ...
|
||||
* launching workers in parallel
|
||||
* ...
|
||||
* */
|
||||
*
|
||||
* // wait for worker execution up to 6 hours.
|
||||
* // providing a finite timeout prevents a long-running task from hanging forever in
|
||||
* // catastrophic events, like the loss of an executor during model training.
|
||||
* tracker.waitFor(6 hours)
|
||||
* }}}
|
||||
*
|
||||
* @param numWorkers Number of distributed workers from which the tracker expects connections.
|
||||
* @param port The minimum port number that the tracker binds to.
|
||||
* If port is omitted, or given as None, a random ephemeral port is chosen at runtime.
|
||||
* @param maxPortTrials The maximum number of trials of socket binding, by sequentially
|
||||
* increasing the port number.
|
||||
*/
|
||||
private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
|
||||
maxPortTrials: Int = 1000)
|
||||
extends IRabitTracker {
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
require(numWorkers >=1, "numWorkers must be greater than or equal to one (1).")
|
||||
|
||||
val system = ActorSystem.create("RabitTracker")
|
||||
val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
|
||||
implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds)
|
||||
private[this] val tcpBindingTimeout: Duration = 1 minute
|
||||
|
||||
var workerEnvs: Map[String, String] = Map.empty
|
||||
|
||||
override def uncaughtException(t: Thread, e: Throwable): Unit = {
|
||||
handler ? RabitTrackerHandler.InterruptTracker(e)
|
||||
}
|
||||
|
||||
/**
|
||||
* Start the Rabit tracker.
|
||||
*
|
||||
* @param timeout The timeout for awaiting connections from worker nodes.
|
||||
* Note that when used in Spark applications, because all Spark transformations are
|
||||
* lazily executed, the I/O time for loading RDDs/DataFrames from external sources
|
||||
* (local dist, HDFS, S3 etc.) must be taken into account for the timeout value.
|
||||
* If the timeout value is too small, the Rabit tracker will likely timeout before workers
|
||||
* establishing connections to the tracker, due to the overhead of loading data.
|
||||
* Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
|
||||
* running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
|
||||
* @return Boolean flag indicating if the Rabit tracker starts successfully.
|
||||
*/
|
||||
private def start(timeout: Duration): Boolean = {
|
||||
handler ? RabitTrackerHandler.StartTracker(
|
||||
new InetSocketAddress(InetAddress.getLocalHost, port.getOrElse(0)), maxPortTrials, timeout)
|
||||
|
||||
// block by waiting for the actor to bind to a port
|
||||
Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
|
||||
.asInstanceOf[Future[Map[String, String]]]) match {
|
||||
case Success(futurePortBound) =>
|
||||
// The success of the Future is contingent on binding to an InetSocketAddress.
|
||||
val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess
|
||||
if (isBound) {
|
||||
workerEnvs = Await.result(futurePortBound, 0 nano)
|
||||
}
|
||||
isBound
|
||||
case Failure(ex: Throwable) =>
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start the Rabit tracker.
|
||||
*
|
||||
* @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
|
||||
* connections. If a non-positive value is provided, the tracker
|
||||
* waits for incoming worker connections indefinitely.
|
||||
* @return Boolean flag indicating if the Rabit tracker starts successfully.
|
||||
*/
|
||||
def start(connectionTimeoutMillis: Long): Boolean = {
|
||||
if (connectionTimeoutMillis <= 0) {
|
||||
start(Duration.Inf)
|
||||
} else {
|
||||
start(Duration.fromNanos(connectionTimeoutMillis * 1e6))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a Map of necessary environment variables to initiate Rabit workers.
|
||||
*
|
||||
* @return HashMap containing tracker information.
|
||||
*/
|
||||
def getWorkerEnvs: java.util.Map[String, String] = {
|
||||
new java.util.HashMap((workerEnvs ++ Map(
|
||||
"DMLC_NUM_WORKER" -> numWorkers.toString,
|
||||
"DMLC_NUM_SERVER" -> "0"
|
||||
)).asJava)
|
||||
}
|
||||
|
||||
/**
|
||||
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
|
||||
* This method blocks until timeout or task completion.
|
||||
*
|
||||
* @param atMost the maximum execution time for the workers. By default,
|
||||
* the tracker waits for the workers indefinitely.
|
||||
* @return 0 if the tasks complete successfully, and non-zero otherwise.
|
||||
*/
|
||||
private def waitFor(atMost: Duration): Int = {
|
||||
// request the completion Future from the tracker actor
|
||||
Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
|
||||
.asInstanceOf[Future[Int]]) match {
|
||||
case Success(futureCompleted) =>
|
||||
// wait for all workers to complete synchronously.
|
||||
val statusCode = Try(Await.result(futureCompleted, atMost)) match {
|
||||
case Success(n) if n == numWorkers =>
|
||||
IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
|
||||
case Success(n) if n < numWorkers =>
|
||||
IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
|
||||
case Failure(e) =>
|
||||
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
|
||||
}
|
||||
system.shutdown()
|
||||
statusCode
|
||||
case Failure(ex: Throwable) =>
|
||||
if (!system.isTerminated) {
|
||||
system.shutdown()
|
||||
}
|
||||
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
|
||||
* This method blocks until timeout or task completion.
|
||||
*
|
||||
* @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a
|
||||
* non-positive number is given, the tracker waits indefinitely.
|
||||
* @return 0 if the tasks complete successfully, and non-zero otherwise
|
||||
*/
|
||||
def waitFor(atMostMillis: Long): Int = {
|
||||
if (atMostMillis <= 0) {
|
||||
waitFor(Duration.Inf)
|
||||
} else {
|
||||
waitFor(Duration.fromNanos(atMostMillis * 1e6))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,362 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.rabit.handler
|
||||
|
||||
import java.net.InetSocketAddress
|
||||
import java.util.UUID
|
||||
|
||||
import scala.concurrent.duration._
|
||||
import scala.collection.mutable
|
||||
import scala.concurrent.{Promise, TimeoutException}
|
||||
import akka.io.{IO, Tcp}
|
||||
import akka.actor._
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}
|
||||
|
||||
import scala.util.{Failure, Random, Success, Try}
|
||||
|
||||
/** The Akka actor for handling and coordinating Rabit worker connections.
|
||||
* This is the main actor for handling socket connections, interacting with the synchronous
|
||||
* tracker interface, and resolving tree/ring/parent dependencies between workers.
|
||||
*
|
||||
* @param numWorkers Number of workers to track.
|
||||
*/
|
||||
private[scala] class RabitTrackerHandler(numWorkers: Int)
|
||||
extends Actor with ActorLogging {
|
||||
|
||||
import context.system
|
||||
import RabitWorkerHandler._
|
||||
import RabitTrackerHandler._
|
||||
|
||||
private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
|
||||
private[this] val promisedShutdownWorkers = Promise[Int]()
|
||||
private[this] val tcpManager = IO(Tcp)
|
||||
|
||||
// resolves worker connection dependency.
|
||||
val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")
|
||||
|
||||
// workers that have sent "shutdown" signal
|
||||
private[this] val shutdownWorkers = mutable.Set.empty[Int]
|
||||
private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
|
||||
private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
|
||||
private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
|
||||
private[this] var maxPortTrials = 0
|
||||
private[this] var workerConnectionTimeout: Duration = Duration.Inf
|
||||
private[this] var portTrials = 0
|
||||
private[this] val startedWorkers = mutable.Set.empty[Int]
|
||||
|
||||
val linkMap = new LinkMap(numWorkers)
|
||||
|
||||
def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
|
||||
rank match {
|
||||
case r if r >= 0 => Some(r)
|
||||
case _ =>
|
||||
jobId match {
|
||||
case "NULL" => None
|
||||
case jid => jobToRankMap.get(jid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
|
||||
* by the RabitWorkerHandler.
|
||||
*
|
||||
* @param event Generic Tcp.Event
|
||||
*/
|
||||
private def handleTcpEvents(event: Tcp.Event): Unit = event match {
|
||||
case Tcp.Bound(local) =>
|
||||
// expect all workers to connect within timeout
|
||||
log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
|
||||
log.info(s"Worker connection timeout is $workerConnectionTimeout.")
|
||||
|
||||
context.setReceiveTimeout(workerConnectionTimeout)
|
||||
promisedWorkerEnvs.success(Map(
|
||||
"DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
|
||||
"DMLC_TRACKER_PORT" -> local.getPort.toString,
|
||||
// not required because the world size will be communicated to the
|
||||
// worker node after the rank is assigned.
|
||||
"rabit_world_size" -> numWorkers.toString
|
||||
))
|
||||
|
||||
case Tcp.CommandFailed(cmd: Tcp.Bind) =>
|
||||
if (portTrials < maxPortTrials) {
|
||||
portTrials += 1
|
||||
tcpManager ! Tcp.Bind(self,
|
||||
new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1),
|
||||
backlog = 256)
|
||||
}
|
||||
|
||||
case Tcp.Connected(remote, local) =>
|
||||
log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
|
||||
// revoke timeout if all workers have connected.
|
||||
val workerHandler = context.actorOf(RabitWorkerHandler.props(
|
||||
remote.getAddress.getHostAddress, numWorkers, self, sender()
|
||||
), s"ConnectionHandler-${UUID.randomUUID().toString}")
|
||||
val connection = sender()
|
||||
connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)
|
||||
|
||||
actorRefToHost.put(workerHandler, remote.getAddress.getHostName)
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
|
||||
* to interact with the tracker interface.
|
||||
*
|
||||
* @param trackerMsg control messages sent by RabitTracker class.
|
||||
*/
|
||||
private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
|
||||
trackerMsg match {
|
||||
|
||||
case msg: StartTracker =>
|
||||
maxPortTrials = msg.maxPortTrials
|
||||
workerConnectionTimeout = msg.connectionTimeout
|
||||
|
||||
// if the port number is missing, try binding to a random ephemeral port.
|
||||
if (msg.addr.getPort == 0) {
|
||||
tcpManager ! Tcp.Bind(self,
|
||||
new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
|
||||
backlog = 256)
|
||||
} else {
|
||||
tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
|
||||
}
|
||||
sender() ! true
|
||||
|
||||
case RequestBoundFuture =>
|
||||
sender() ! promisedWorkerEnvs.future
|
||||
|
||||
case RequestCompletionFuture =>
|
||||
sender() ! promisedShutdownWorkers.future
|
||||
|
||||
case InterruptTracker(e) =>
|
||||
log.error(e, "Uncaught exception thrown by worker.")
|
||||
// make sure that waitFor() does not hang indefinitely.
|
||||
promisedShutdownWorkers.failure(e)
|
||||
context.stop(self)
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles messages sent by child actors representing connecting Rabit workers, by brokering
|
||||
* messages to the dependency resolver, and processing worker commands.
|
||||
*
|
||||
* @param workerMsg Message sent by RabitWorkerHandler actors.
|
||||
*/
|
||||
private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
|
||||
case req @ RequestAwaitConnWorkers(_, _) =>
|
||||
// since the requester may request to connect to other workers
|
||||
// that have not fully set up, delegate this request to the
|
||||
// dependency resolver which handles the dependencies properly.
|
||||
resolver forward req
|
||||
|
||||
// ---- Rabit worker commands: start/recover/shutdown/print ----
|
||||
case WorkerTrackerPrint(_, _, _, msg) =>
|
||||
log.info(msg.trim)
|
||||
|
||||
case WorkerShutdown(rank, _, _) =>
|
||||
assert(rank >= 0, "Invalid rank.")
|
||||
assert(!shutdownWorkers.contains(rank))
|
||||
shutdownWorkers.add(rank)
|
||||
|
||||
log.info(s"Received shutdown signal from $rank")
|
||||
|
||||
if (shutdownWorkers.size == numWorkers) {
|
||||
promisedShutdownWorkers.success(shutdownWorkers.size)
|
||||
context.stop(self)
|
||||
}
|
||||
|
||||
case WorkerRecover(prevRank, worldSize, jobId) =>
|
||||
assert(prevRank >= 0)
|
||||
sender() ! linkMap.assignRank(prevRank)
|
||||
|
||||
case WorkerStart(rank, worldSize, jobId) =>
|
||||
assert(worldSize == numWorkers || worldSize == -1,
|
||||
s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
|
||||
)
|
||||
|
||||
Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
|
||||
case Success(wkRank) =>
|
||||
if (jobId != "NULL") {
|
||||
jobToRankMap.put(jobId, wkRank)
|
||||
}
|
||||
|
||||
val assignedRank = linkMap.assignRank(wkRank)
|
||||
sender() ! assignedRank
|
||||
resolver ! assignedRank
|
||||
|
||||
log.info("Received start signal from " +
|
||||
s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]")
|
||||
|
||||
case Failure(ex: IndexOutOfBoundsException) =>
|
||||
// More than worldSize workers have connected, likely due to executor loss.
|
||||
// Since Rabit currently does not support crash recovery (because the Allreduce results
|
||||
// are not cached by the tracker, and because existing workers cannot reestablish
|
||||
// connections to newly spawned executor/worker), the most reasonble action here is to
|
||||
// interrupt the tracker immediate with failure state.
|
||||
log.error("Received invalid start signal from " +
|
||||
s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started."
|
||||
)
|
||||
promisedShutdownWorkers.failure(new XGBoostError("Invalid start signal" +
|
||||
" received from worker, likely due to executor loss."))
|
||||
|
||||
case Failure(ex) =>
|
||||
log.error(ex, "Unexpected error")
|
||||
promisedShutdownWorkers.failure(ex)
|
||||
}
|
||||
|
||||
|
||||
// ---- Dependency resolving related messages ----
|
||||
case msg @ WorkerStarted(host, rank, awaitingAcceptance) =>
|
||||
log.info(s"Worker $host (rank: $rank) has started.")
|
||||
resolver forward msg
|
||||
|
||||
startedWorkers.add(rank)
|
||||
if (startedWorkers.size == numWorkers) {
|
||||
log.info("All workers have started.")
|
||||
}
|
||||
|
||||
case req @ DropFromWaitingList(_) =>
|
||||
// all peer workers in dependency link map have connected;
|
||||
// forward message to resolver to update dependencies.
|
||||
resolver forward req
|
||||
|
||||
case _ =>
|
||||
}
|
||||
|
||||
def receive: Actor.Receive = {
|
||||
case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
|
||||
case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
|
||||
case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)
|
||||
|
||||
case akka.actor.ReceiveTimeout =>
|
||||
if (startedWorkers.size < numWorkers) {
|
||||
promisedShutdownWorkers.failure(
|
||||
new TimeoutException("Timed out waiting for workers to connect: " +
|
||||
s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
|
||||
)
|
||||
context.stop(self)
|
||||
}
|
||||
|
||||
context.setReceiveTimeout(Duration.Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve the dependency between nodes as they connect to the tracker.
|
||||
* The dependency is enforced that a worker of rank K depends on its neighbors (from the treeMap
|
||||
* and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order of
|
||||
* connections by workers, this dependency constraint assumes that a worker node connects first
|
||||
* is likely to finish setup first.
|
||||
*/
|
||||
private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
|
||||
import RabitWorkerHandler._
|
||||
|
||||
context.watch(handler)
|
||||
|
||||
case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])
|
||||
|
||||
// worker nodes that have connected, but have not send WorkerStarted message.
|
||||
private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
|
||||
private val startedWorkers = mutable.Set.empty[Int]
|
||||
// worker nodes that have started, and await for connections.
|
||||
private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
|
||||
private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]
|
||||
|
||||
def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
|
||||
val connSet = awaitConnWorkers.toMap
|
||||
.filterKeys(k => linkSet.contains(k))
|
||||
AwaitingConnections(connSet, linkSet.size - connSet.size)
|
||||
}
|
||||
|
||||
def receive: Actor.Receive = {
|
||||
// a copy of the AssignedRank message that is also sent to the worker
|
||||
case AssignedRank(rank, tree_neighbors, ring, parent) =>
|
||||
// the workers that the worker of given `rank` depends on:
|
||||
// worker of rank K only depends on workers with rank smaller than K.
|
||||
val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
|
||||
.filter{ r => r != -1 && r < rank}
|
||||
|
||||
log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
|
||||
dependencyMap.put(rank, dependentWorkers)
|
||||
|
||||
case RequestAwaitConnWorkers(rank, toConnectSet) =>
|
||||
val promise = Promise[AwaitingConnections]()
|
||||
|
||||
assert(dependencyMap.contains(rank))
|
||||
|
||||
val updatedDependency = dependencyMap(rank) diff startedWorkers
|
||||
if (updatedDependency.isEmpty) {
|
||||
// all dependencies are satisfied
|
||||
log.debug(s"Rank $rank has all dependencies satisfied.")
|
||||
promise.success(awaitingWorkers(toConnectSet))
|
||||
} else {
|
||||
log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
|
||||
// promise is pending fulfillment due to unresolved dependency
|
||||
pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
|
||||
}
|
||||
|
||||
sender() ! promise.future
|
||||
|
||||
case WorkerStarted(_, started, awaitingAcceptance) =>
|
||||
startedWorkers.add(started)
|
||||
if (awaitingAcceptance > 0) {
|
||||
awaitConnWorkers.put(started, sender())
|
||||
}
|
||||
|
||||
// remove the started rank from all dependencies.
|
||||
dependencyMap.remove(started)
|
||||
dependencyMap.foreach { case (r, dset) =>
|
||||
val updatedDependency = dset diff startedWorkers
|
||||
// fulfill the future if all dependencies are met (started.)
|
||||
if (updatedDependency.isEmpty) {
|
||||
log.debug(s"Rank $r has all dependencies satisfied.")
|
||||
pendingFulfillment.remove(r).map{
|
||||
case Fulfillment(toConnectSet, promise) =>
|
||||
promise.success(awaitingWorkers(toConnectSet))
|
||||
}
|
||||
}
|
||||
|
||||
dependencyMap.update(r, updatedDependency)
|
||||
}
|
||||
|
||||
case DropFromWaitingList(rank) =>
|
||||
assert(awaitConnWorkers.remove(rank).isDefined)
|
||||
|
||||
case Terminated(ref) =>
|
||||
if (ref.equals(handler)) {
|
||||
context.stop(self)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[scala] object RabitTrackerHandler {
|
||||
// Messages sent by RabitTracker to this RabitTrackerHandler actor
|
||||
trait TrackerControlMessage
|
||||
case object RequestCompletionFuture extends TrackerControlMessage
|
||||
case object RequestBoundFuture extends TrackerControlMessage
|
||||
// Start the Rabit tracker at given socket address awaiting worker connections.
|
||||
// All workers must connect to the tracker before connectionTimeout, otherwise the tracker will
|
||||
// shut down due to timeout.
|
||||
case class StartTracker(addr: InetSocketAddress,
|
||||
maxPortTrials: Int,
|
||||
connectionTimeout: Duration) extends TrackerControlMessage
|
||||
// To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
|
||||
// driver for the distributed training.
|
||||
case class InterruptTracker(e: Throwable) extends TrackerControlMessage
|
||||
|
||||
def props(numWorkers: Int): Props =
|
||||
Props(new RabitTrackerHandler(numWorkers))
|
||||
}
|
||||
@ -0,0 +1,456 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.rabit.handler
|
||||
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
|
||||
import akka.io.Tcp
|
||||
import akka.actor._
|
||||
import akka.util.ByteString
|
||||
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}
|
||||
|
||||
import scala.concurrent.{Await, Future}
|
||||
import scala.concurrent.duration._
|
||||
import scala.util.Try
|
||||
|
||||
/**
|
||||
* Actor to handle socket communication from worker node.
|
||||
* To handle fragmentation in received data, this class acts like a FSM
|
||||
* (finite-state machine) to keep track of the internal states.
|
||||
*
|
||||
* @param host IP address of the remote worker
|
||||
* @param worldSize number of total workers
|
||||
* @param tracker the RabitTrackerHandler actor reference
|
||||
*/
|
||||
private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
|
||||
connection: ActorRef)
|
||||
extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
|
||||
with ActorLogging with Stash {
|
||||
|
||||
import RabitWorkerHandler._
|
||||
import RabitTrackerHelpers._
|
||||
|
||||
private[this] var rank: Int = 0
|
||||
private[this] var port: Int = 0
|
||||
|
||||
// indicate if the connection is transient (like "print" or "shutdown")
|
||||
private[this] var transient: Boolean = false
|
||||
private[this] var peerClosed: Boolean = false
|
||||
|
||||
// number of workers pending acceptance of current worker
|
||||
private[this] var awaitingAcceptance: Int = 0
|
||||
private[this] var neighboringWorkers = Set.empty[Int]
|
||||
|
||||
// TODO: use a single memory allocation to host all buffers,
|
||||
// including the transient ones for writing.
|
||||
private[this] val readBuffer = ByteBuffer.allocate(4096)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
// in case the received message is longer than needed,
|
||||
// stash the spilled over part in this buffer, and send
|
||||
// to self when transition occurs.
|
||||
private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
// when setup is complete, need to notify peer handlers
|
||||
// to reduce the awaiting-connection counter.
|
||||
private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None
|
||||
|
||||
private def resetBuffers(): Unit = {
|
||||
readBuffer.clear()
|
||||
if (spillOverBuffer.position() > 0) {
|
||||
spillOverBuffer.flip()
|
||||
self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
|
||||
spillOverBuffer.clear()
|
||||
}
|
||||
}
|
||||
|
||||
private def stashSpillOver(buf: ByteBuffer): Unit = {
|
||||
if (buf.remaining() > 0) spillOverBuffer.put(buf)
|
||||
}
|
||||
|
||||
def getNeighboringWorkers: Set[Int] = neighboringWorkers
|
||||
|
||||
def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
|
||||
val rank = buffer.getInt()
|
||||
val worldSize = buffer.getInt()
|
||||
val jobId = buffer.getString
|
||||
|
||||
val command = buffer.getString
|
||||
command match {
|
||||
case "start" => WorkerStart(rank, worldSize, jobId)
|
||||
case "shutdown" =>
|
||||
transient = true
|
||||
WorkerShutdown(rank, worldSize, jobId)
|
||||
case "recover" =>
|
||||
require(rank >= 0, "Invalid rank for recovering worker.")
|
||||
WorkerRecover(rank, worldSize, jobId)
|
||||
case "print" =>
|
||||
transient = true
|
||||
WorkerTrackerPrint(rank, worldSize, jobId, buffer.getString)
|
||||
}
|
||||
}
|
||||
|
||||
startWith(AwaitingHandshake, DataStruct())
|
||||
|
||||
when(AwaitingHandshake) {
|
||||
case Event(Tcp.Received(magic), _) =>
|
||||
assert(magic.length == 4)
|
||||
val purportedMagic = magic.asNativeOrderByteBuffer.getInt
|
||||
assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")
|
||||
|
||||
// echo back the magic number
|
||||
connection ! Tcp.Write(magic)
|
||||
goto(AwaitingCommand) using StructTrackerCommand
|
||||
}
|
||||
|
||||
when(AwaitingCommand) {
|
||||
case Event(Tcp.Received(bytes), validator) =>
|
||||
bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
|
||||
if (validator.verify(readBuffer)) {
|
||||
readBuffer.flip()
|
||||
tracker ! decodeCommand(readBuffer)
|
||||
stashSpillOver(readBuffer)
|
||||
}
|
||||
|
||||
stay
|
||||
// when rank for a worker is assigned, send encoded rank information
|
||||
// back to worker over Tcp socket.
|
||||
case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
|
||||
log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")
|
||||
|
||||
rank = assignedRank
|
||||
// ranks from the ring
|
||||
val ringRanks = List(
|
||||
// ringPrev
|
||||
if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
|
||||
// ringNext
|
||||
if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
|
||||
)
|
||||
|
||||
// update the set of all linked workers to current worker.
|
||||
neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet
|
||||
|
||||
connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
|
||||
// to prevent reading before state transition
|
||||
connection ! Tcp.SuspendReading
|
||||
goto(BuildingLinkMap) using StructNodes
|
||||
}
|
||||
|
||||
when(BuildingLinkMap) {
|
||||
case Event(Tcp.Received(bytes), validator) =>
|
||||
bytes.asByteBuffers.foreach { buf =>
|
||||
readBuffer.put(buf)
|
||||
}
|
||||
|
||||
if (validator.verify(readBuffer)) {
|
||||
readBuffer.flip()
|
||||
// for a freshly started worker, numConnected should be 0.
|
||||
val numConnected = readBuffer.getInt()
|
||||
val toConnectSet = neighboringWorkers.diff(
|
||||
(0 until numConnected).map { index => readBuffer.getInt() }.toSet)
|
||||
|
||||
// check which workers are currently awaiting connections
|
||||
tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
|
||||
}
|
||||
stay
|
||||
|
||||
// got a Future from the tracker (resolver) about workers that are
|
||||
// currently awaiting connections (particularly from this node.)
|
||||
case Event(future: Future[_], _) =>
|
||||
// blocks execution until all dependencies for current worker is resolved.
|
||||
Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
|
||||
// numNotReachable is the number of workers that currently
|
||||
// cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
|
||||
// connections from those currently non-reachable nodes in the future.
|
||||
case AwaitingConnections(waitConnNodes, numNotReachable) =>
|
||||
log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")
|
||||
val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
|
||||
buf.putInt(waitConnNodes.size).putInt(numNotReachable)
|
||||
buf.flip()
|
||||
|
||||
// cache this message until the final state (SetupComplete)
|
||||
pendingAcknowledgement = Some(AcknowledgeAcceptance(
|
||||
waitConnNodes, numNotReachable))
|
||||
|
||||
connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
|
||||
if (waitConnNodes.isEmpty) {
|
||||
connection ! Tcp.SuspendReading
|
||||
goto(AwaitingErrorCount)
|
||||
}
|
||||
else {
|
||||
waitConnNodes.foreach { case (peerRank, peerRef) =>
|
||||
peerRef ! RequestWorkerHostPort
|
||||
}
|
||||
|
||||
// a countdown for DivulgedHostPort messages.
|
||||
stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
|
||||
}
|
||||
}
|
||||
|
||||
case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
|
||||
val hostBytes = peerHost.getBytes()
|
||||
val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
buffer.putInt(peerHost.length).put(hostBytes)
|
||||
.putInt(peerPort).putInt(peerRank)
|
||||
|
||||
buffer.flip()
|
||||
connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))
|
||||
|
||||
if (data.counter == 0) {
|
||||
// to prevent reading before state transition
|
||||
connection ! Tcp.SuspendReading
|
||||
goto(AwaitingErrorCount)
|
||||
}
|
||||
else {
|
||||
stay using data.decrement()
|
||||
}
|
||||
}
|
||||
|
||||
when(AwaitingErrorCount) {
|
||||
case Event(Tcp.Received(numErrors), _) =>
|
||||
val buf = numErrors.asNativeOrderByteBuffer
|
||||
|
||||
buf.getInt match {
|
||||
case 0 =>
|
||||
stashSpillOver(buf)
|
||||
goto(AwaitingPortNumber)
|
||||
case _ =>
|
||||
stashSpillOver(buf)
|
||||
goto(BuildingLinkMap) using StructNodes
|
||||
}
|
||||
}
|
||||
|
||||
when(AwaitingPortNumber) {
|
||||
case Event(Tcp.Received(assignedPort), _) =>
|
||||
assert(assignedPort.length == 4)
|
||||
port = assignedPort.asNativeOrderByteBuffer.getInt
|
||||
log.debug(s"Rank $rank listening @ $host:$port")
|
||||
// wait until the worker closes connection.
|
||||
if (peerClosed) goto(SetupComplete) else stay
|
||||
|
||||
case Event(Tcp.PeerClosed, _) =>
|
||||
peerClosed = true
|
||||
if (port == 0) stay else goto(SetupComplete)
|
||||
}
|
||||
|
||||
when(SetupComplete) {
|
||||
case Event(ReduceWaitCount(count: Int), _) =>
|
||||
awaitingAcceptance -= count
|
||||
// check peerClosed to avoid prematurely stopping this actor (which sends RST to worker)
|
||||
if (awaitingAcceptance == 0 && peerClosed) {
|
||||
tracker ! DropFromWaitingList(rank)
|
||||
// no longer needed.
|
||||
context.stop(self)
|
||||
}
|
||||
stay
|
||||
|
||||
case Event(AcknowledgeAcceptance(peers, numBad), _) =>
|
||||
awaitingAcceptance = numBad
|
||||
tracker ! WorkerStarted(host, rank, awaitingAcceptance)
|
||||
peers.values.foreach { peer =>
|
||||
peer ! ReduceWaitCount(1)
|
||||
}
|
||||
|
||||
if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill
|
||||
|
||||
stay
|
||||
|
||||
// can only divulge the complete host and port information
|
||||
// when this worker is declared fully connected (otherwise
|
||||
// port information is still missing.)
|
||||
case Event(RequestWorkerHostPort, _) =>
|
||||
sender() ! DivulgedWorkerHostPort(rank, host, port)
|
||||
stay
|
||||
}
|
||||
|
||||
onTransition {
|
||||
// reset buffer when state transitions as data becomes stale
|
||||
case _ -> SetupComplete =>
|
||||
connection ! Tcp.ResumeReading
|
||||
resetBuffers()
|
||||
if (pendingAcknowledgement.isDefined) {
|
||||
self ! pendingAcknowledgement.get
|
||||
}
|
||||
case _ =>
|
||||
connection ! Tcp.ResumeReading
|
||||
resetBuffers()
|
||||
}
|
||||
|
||||
// default message handler
|
||||
whenUnhandled {
|
||||
case Event(Tcp.PeerClosed, _) =>
|
||||
peerClosed = true
|
||||
if (transient) context.stop(self)
|
||||
stay
|
||||
}
|
||||
}
|
||||
|
||||
private[scala] object RabitWorkerHandler {
|
||||
val MAGIC_NUMBER = 0xff99
|
||||
|
||||
// Finite states of this actor, which acts like a FSM.
|
||||
// The following states are defined in order as the FSM progresses.
|
||||
sealed trait State
|
||||
|
||||
// [1] Initial state, awaiting worker to send magic number per protocol.
|
||||
case object AwaitingHandshake extends State
|
||||
// [2] Awaiting worker to send command (start/print/recover/shutdown etc.)
|
||||
case object AwaitingCommand extends State
|
||||
// [3] Brokers connections between workers per ring/tree/parent link map.
|
||||
case object BuildingLinkMap extends State
|
||||
// [4] A transient state in which the worker reports the number of errors in establishing
|
||||
// connections to other peer workers. If no errors, transition to next state.
|
||||
case object AwaitingErrorCount extends State
|
||||
// [5] Awaiting the worker to report its port number for accepting connections from peer workers.
|
||||
// This port number information is later forwarded to linked workers.
|
||||
case object AwaitingPortNumber extends State
|
||||
// [6] Final state after completing the setup with the connecting worker. At this stage, the
|
||||
// worker will have closed the Tcp connection. The actor remains alive to handle messages from
|
||||
// peer actors representing workers with pending setups.
|
||||
case object SetupComplete extends State
|
||||
|
||||
sealed trait DataField
|
||||
case object IntField extends DataField
|
||||
// an integer preceding the actual string
|
||||
case object StringField extends DataField
|
||||
case object IntSeqField extends DataField
|
||||
|
||||
object DataStruct {
|
||||
def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
|
||||
}
|
||||
|
||||
// Internal data pertaining to individual state, used to verify the validity of packets sent by
|
||||
// workers.
|
||||
case class DataStruct(fields: Seq[DataField], counter: Int) {
|
||||
/**
|
||||
* Validate whether the provided buffer is complete (i.e., contains
|
||||
* all data fields specified for this DataStruct.)
|
||||
*
|
||||
* @param buf a byte buffer containing received data.
|
||||
*/
|
||||
def verify(buf: ByteBuffer): Boolean = {
|
||||
if (fields.isEmpty) return true
|
||||
|
||||
val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
|
||||
dupBuf.flip()
|
||||
|
||||
Try(fields.foldLeft(true) {
|
||||
case (complete, field) =>
|
||||
val remBytes = dupBuf.remaining()
|
||||
complete && (remBytes > 0) && (remBytes >= (field match {
|
||||
case IntField =>
|
||||
dupBuf.position(dupBuf.position() + 4)
|
||||
4
|
||||
case StringField =>
|
||||
val strLen = dupBuf.getInt
|
||||
dupBuf.position(dupBuf.position() + strLen)
|
||||
4 + strLen
|
||||
case IntSeqField =>
|
||||
val seqLen = dupBuf.getInt
|
||||
dupBuf.position(dupBuf.position() + seqLen * 4)
|
||||
4 + seqLen * 4
|
||||
}))
|
||||
}).getOrElse(false)
|
||||
}
|
||||
|
||||
def increment(): DataStruct = DataStruct(fields, counter + 1)
|
||||
def decrement(): DataStruct = DataStruct(fields, counter - 1)
|
||||
}
|
||||
|
||||
val StructNodes = DataStruct(List(IntSeqField), 0)
|
||||
val StructTrackerCommand = DataStruct(List(
|
||||
IntField, IntField, StringField, StringField
|
||||
), 0)
|
||||
|
||||
// ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ----
|
||||
|
||||
// RabitWorkerHandler --> RabitTrackerHandler
|
||||
sealed trait RabitWorkerRequest
|
||||
// RabitWorkerHandler <-- RabitTrackerHandler
|
||||
sealed trait RabitWorkerResponse
|
||||
|
||||
// Representations of decoded worker commands.
|
||||
abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
|
||||
def rank: Int
|
||||
def worldSize: Int
|
||||
def jobId: String
|
||||
|
||||
def encode: ByteString = {
|
||||
val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
|
||||
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
|
||||
.putInt(command.length).put(command.getBytes()).flip()
|
||||
|
||||
ByteString.fromByteBuffer(buf)
|
||||
}
|
||||
}
|
||||
|
||||
case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
|
||||
extends TrackerCommand("start")
|
||||
case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
|
||||
extends TrackerCommand("shutdown")
|
||||
case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
|
||||
extends TrackerCommand("recover")
|
||||
case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
|
||||
extends TrackerCommand("print") {
|
||||
|
||||
override def encode: ByteString = {
|
||||
val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
|
||||
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
|
||||
.putInt(command.length).put(command.getBytes())
|
||||
.putInt(msg.length).put(msg.getBytes()).flip()
|
||||
|
||||
ByteString.fromByteBuffer(buf)
|
||||
}
|
||||
}
|
||||
|
||||
// Request to remove the worker of given rank from the list of workers awaiting peer connections.
|
||||
case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
|
||||
// Notify the tracker that the worker of given rank has finished setup and started.
|
||||
case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
|
||||
extends RabitWorkerRequest
|
||||
// Request the set of workers to connect to, according to the LinkMap structure.
|
||||
case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
|
||||
extends RabitWorkerRequest
|
||||
|
||||
// Request, from the tracker, the set of nodes to connect.
|
||||
case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
|
||||
extends RabitWorkerResponse
|
||||
|
||||
// ---- Messages between ConnectionHandler actors ----
|
||||
sealed trait IntraWorkerMessage
|
||||
|
||||
// Notify neighboring workers to decrease the counter of awaiting workers by `count`.
|
||||
case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
|
||||
// Request host and port information from peer ConnectionHandler actors (acting on behave of
|
||||
// connecting workers.) This message will be brokered by RabitTrackerHandler.
|
||||
case object RequestWorkerHostPort extends IntraWorkerMessage
|
||||
// Response to the above request
|
||||
case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
|
||||
// A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
|
||||
case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
|
||||
extends IntraWorkerMessage
|
||||
|
||||
// ---- End of message definitions ----
|
||||
|
||||
def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
|
||||
Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,136 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.rabit.util
|
||||
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
|
||||
/**
|
||||
* The assigned rank to a connecting Rabit worker, along with the information of the ranks of
|
||||
* its linked peer workers, which are critical to perform Allreduce.
|
||||
* When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
|
||||
* client, RabitTrackerHandler utilizes LinkMap to figure out linkage relationships, and respond
|
||||
* with this class as a message, which is later encoded to byte string, and sent over socket
|
||||
* connection to the worker client.
|
||||
*
|
||||
* @param rank assigned rank (ranked by worker connection order: first worker connecting to the
|
||||
* tracker is assigned rank 0, second with rank 1, etc.)
|
||||
* @param neighbors ranks of neighboring workers in a tree map.
|
||||
* @param ring ranks of neighboring workers in a ring map.
|
||||
* @param parent rank of the parent worker.
|
||||
*/
|
||||
private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
|
||||
ring: (Int, Int), parent: Int) {
|
||||
/**
|
||||
* Encode the AssignedRank message into byte sequence for socket communication with Rabit worker
|
||||
* client.
|
||||
* @param worldSize the number of total distributed workers. Must match `numWorkers` used in
|
||||
* LinkMap.
|
||||
* @return a ByteBuffer containing encoded data.
|
||||
*/
|
||||
def toByteBuffer(worldSize: Int): ByteBuffer = {
|
||||
val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder())
|
||||
buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length)
|
||||
// neighbors in tree structure
|
||||
neighbors.foreach { n => buffer.putInt(n) }
|
||||
buffer.putInt(if (ring._1 != -1 && ring._1 != rank) ring._1 else -1)
|
||||
buffer.putInt(if (ring._2 != -1 && ring._2 != rank) ring._2 else -1)
|
||||
|
||||
buffer.flip()
|
||||
buffer
|
||||
}
|
||||
}
|
||||
|
||||
private[rabit] class LinkMap(numWorkers: Int) {
|
||||
private def getNeighbors(rank: Int): Seq[Int] = {
|
||||
val rank1 = rank + 1
|
||||
Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
|
||||
r >= 0 && r < numWorkers
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a ring structure that tends to share nodes with the tree.
|
||||
*
|
||||
* @param treeMap
|
||||
* @param parentMap
|
||||
* @param rank
|
||||
* @return Seq[Int] instance starting from rank.
|
||||
*/
|
||||
private def constructShareRing(treeMap: Map[Int, Seq[Int]],
|
||||
parentMap: Map[Int, Int],
|
||||
rank: Int = 0): Seq[Int] = {
|
||||
treeMap(rank).toSet - parentMap(rank) match {
|
||||
case emptySet if emptySet.isEmpty =>
|
||||
List(rank)
|
||||
case connectionSet =>
|
||||
connectionSet.zipWithIndex.foldLeft(List(rank)) {
|
||||
case (ringSeq, (v, cnt)) =>
|
||||
val vConnSeq = constructShareRing(treeMap, parentMap, v)
|
||||
vConnSeq match {
|
||||
case vconn if vconn.size == cnt + 1 =>
|
||||
ringSeq ++ vconn.reverse
|
||||
case vconn =>
|
||||
ringSeq ++ vconn
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Construct a ring connection used to recover local data.
|
||||
*
|
||||
* @param treeMap
|
||||
* @param parentMap
|
||||
*/
|
||||
private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = {
|
||||
assert(parentMap(0) == -1)
|
||||
|
||||
val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
|
||||
assert(sharedRing.length == treeMap.size)
|
||||
|
||||
(0 until numWorkers).map { r =>
|
||||
val rPrev = (r + numWorkers - 1) % numWorkers
|
||||
val rNext = (r + 1) % numWorkers
|
||||
sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
|
||||
}.toMap
|
||||
}
|
||||
|
||||
private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
|
||||
private[this] val parentMap_ = (0 until numWorkers).map{ r => r -> ((r + 1) / 2 - 1) }.toMap
|
||||
private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)
|
||||
val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
|
||||
case ((rmap, k), i) =>
|
||||
val kNext = ringMap_(k)._2
|
||||
(rmap ++ Map(kNext -> (i + 1)), kNext)
|
||||
}._1
|
||||
|
||||
val ringMap = ringMap_.map {
|
||||
case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
|
||||
}
|
||||
val treeMap = treeMap_.map {
|
||||
case (k, vSeq) => rMap_(k) -> vSeq.map{ v => rMap_(v) }
|
||||
}
|
||||
val parentMap = parentMap_.map {
|
||||
case (k, v) if k == 0 =>
|
||||
rMap_(k) -> -1
|
||||
case (k, v) =>
|
||||
rMap_(k) -> rMap_(v)
|
||||
}
|
||||
|
||||
def assignRank(rank: Int): AssignedRank = {
|
||||
AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,39 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.rabit.util
|
||||
|
||||
import java.nio.{ByteOrder, ByteBuffer}
|
||||
import akka.util.ByteString
|
||||
|
||||
private[rabit] object RabitTrackerHelpers {
|
||||
implicit class ByteStringHelplers(bs: ByteString) {
|
||||
// Java by default uses big endian. Enforce native endian so that
|
||||
// the byte order is consistent with the workers.
|
||||
def asNativeOrderByteBuffer: ByteBuffer = {
|
||||
bs.asByteBuffer.order(ByteOrder.nativeOrder())
|
||||
}
|
||||
}
|
||||
|
||||
implicit class ByteBufferHelpers(buf: ByteBuffer) {
|
||||
def getString: String = {
|
||||
val len = buf.getInt()
|
||||
val stringBuffer = ByteBuffer.allocate(len).order(ByteOrder.nativeOrder())
|
||||
buf.get(stringBuffer.array(), 0, len)
|
||||
new String(stringBuffer.array(), "utf-8")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -94,6 +94,7 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
||||
long max_elem = cbatch.offset[cbatch.size];
|
||||
cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0);
|
||||
cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);
|
||||
|
||||
CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
|
||||
<< "batch.index.length must equal batch.offset.back()";
|
||||
CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem)
|
||||
@ -756,3 +757,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitAllreduce
|
||||
* Signature: (Ljava/nio/ByteBuffer;III)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||
(JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
|
||||
void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
|
||||
RabitAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op, NULL, NULL);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -303,6 +303,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitAllreduce
|
||||
* Signature: (Ljava/nio/ByteBuffer;III)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||
(JNIEnv *, jclass, jobject, jint, jint, jint);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -0,0 +1,224 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.rabit
|
||||
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
|
||||
import akka.actor.{ActorRef, ActorSystem}
|
||||
import akka.io.Tcp
|
||||
import akka.testkit.{ImplicitSender, TestFSMRef, TestKit, TestProbe}
|
||||
import akka.util.ByteString
|
||||
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler
|
||||
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler._
|
||||
import ml.dmlc.xgboost4j.scala.rabit.util.LinkMap
|
||||
import org.junit.runner.RunWith
|
||||
import org.scalatest.junit.JUnitRunner
|
||||
import org.scalatest.{FlatSpecLike, Matchers}
|
||||
|
||||
import scala.concurrent.Promise
|
||||
|
||||
object RabitTrackerConnectionHandlerTest {
|
||||
def intSeqToByteString(seq: Seq[Int]): ByteString = {
|
||||
val buf = ByteBuffer.allocate(seq.length * 4).order(ByteOrder.nativeOrder())
|
||||
seq.foreach { i => buf.putInt(i) }
|
||||
buf.flip()
|
||||
ByteString.fromByteBuffer(buf)
|
||||
}
|
||||
}
|
||||
|
||||
@RunWith(classOf[JUnitRunner])
|
||||
class RabitTrackerConnectionHandlerTest
|
||||
extends TestKit(ActorSystem("RabitTrackerConnectionHandlerTest"))
|
||||
with FlatSpecLike with Matchers with ImplicitSender {
|
||||
|
||||
import RabitTrackerConnectionHandlerTest._
|
||||
|
||||
val magic = intSeqToByteString(List(0xff99))
|
||||
|
||||
"RabitTrackerConnectionHandler" should "handle Rabit client 'start' command properly" in {
|
||||
val trackerProbe = TestProbe()
|
||||
val connProbe = TestProbe()
|
||||
|
||||
val worldSize = 4
|
||||
|
||||
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
|
||||
trackerProbe.ref, connProbe.ref))
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
|
||||
|
||||
// send mock magic number
|
||||
fsm ! Tcp.Received(magic)
|
||||
connProbe.expectMsg(Tcp.Write(magic))
|
||||
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
|
||||
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
|
||||
// ResumeReading should be seen once state transitions
|
||||
connProbe.expectMsg(Tcp.ResumeReading)
|
||||
|
||||
// send mock tracker command in fragments: the handler should be able to handle it.
|
||||
val bufRank = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
|
||||
bufRank.putInt(0).putInt(worldSize).flip()
|
||||
|
||||
val bufJobId = ByteBuffer.allocate(5).order(ByteOrder.nativeOrder())
|
||||
bufJobId.putInt(1).put(Array[Byte]('0')).flip()
|
||||
|
||||
val bufCmd = ByteBuffer.allocate(9).order(ByteOrder.nativeOrder())
|
||||
bufCmd.putInt(5).put("start".getBytes()).flip()
|
||||
|
||||
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufRank))
|
||||
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufJobId))
|
||||
|
||||
// the state should not change for incomplete command data.
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
|
||||
|
||||
// send the last fragment, and expect message at tracker actor.
|
||||
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
|
||||
trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
|
||||
|
||||
val linkMap = new LinkMap(worldSize)
|
||||
val assignedRank = linkMap.assignRank(0)
|
||||
trackerProbe.reply(assignedRank)
|
||||
|
||||
connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
|
||||
assignedRank.toByteBuffer(worldSize)
|
||||
)))
|
||||
|
||||
// reading should be suspended upon transitioning to BuildingLinkMap
|
||||
connProbe.expectMsg(Tcp.SuspendReading)
|
||||
// state should transition with according state data changes.
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
|
||||
fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
|
||||
connProbe.expectMsg(Tcp.ResumeReading)
|
||||
|
||||
// since the connection handler in test has rank 0, it will not have any nodes to connect to.
|
||||
fsm ! Tcp.Received(intSeqToByteString(List(0)))
|
||||
trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
|
||||
|
||||
// return mock response to the connection handler
|
||||
val awaitConnPromise = Promise[AwaitingConnections]()
|
||||
awaitConnPromise.success(AwaitingConnections(Map.empty[Int, ActorRef],
|
||||
fsm.underlyingActor.getNeighboringWorkers.size
|
||||
))
|
||||
fsm ! awaitConnPromise.future
|
||||
connProbe.expectMsg(Tcp.Write(
|
||||
intSeqToByteString(List(0, fsm.underlyingActor.getNeighboringWorkers.size))
|
||||
))
|
||||
connProbe.expectMsg(Tcp.SuspendReading)
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingErrorCount
|
||||
connProbe.expectMsg(Tcp.ResumeReading)
|
||||
|
||||
// send mock error count (0)
|
||||
fsm ! Tcp.Received(intSeqToByteString(List(0)))
|
||||
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
|
||||
connProbe.expectMsg(Tcp.ResumeReading)
|
||||
|
||||
// simulate Tcp.PeerClosed event first, then Tcp.Received to test handling of async events.
|
||||
fsm ! Tcp.PeerClosed
|
||||
// state should not transition
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
|
||||
fsm ! Tcp.Received(intSeqToByteString(List(32768)))
|
||||
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
|
||||
connProbe.expectMsg(Tcp.ResumeReading)
|
||||
|
||||
trackerProbe.expectMsg(RabitWorkerHandler.WorkerStarted("localhost", 0, 2))
|
||||
|
||||
val handlerStopProbe = TestProbe()
|
||||
handlerStopProbe watch fsm
|
||||
|
||||
// simulate connections from other workers by mocking ReduceWaitCount commands
|
||||
fsm ! RabitWorkerHandler.ReduceWaitCount(1)
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
|
||||
fsm ! RabitWorkerHandler.ReduceWaitCount(1)
|
||||
trackerProbe.expectMsg(RabitWorkerHandler.DropFromWaitingList(0))
|
||||
handlerStopProbe.expectTerminated(fsm)
|
||||
|
||||
// all done.
|
||||
}
|
||||
|
||||
it should "forward print command to tracker" in {
|
||||
val trackerProbe = TestProbe()
|
||||
val connProbe = TestProbe()
|
||||
|
||||
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
|
||||
trackerProbe.ref, connProbe.ref))
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
|
||||
|
||||
fsm ! Tcp.Received(magic)
|
||||
connProbe.expectMsg(Tcp.Write(magic))
|
||||
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
|
||||
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
|
||||
// ResumeReading should be seen once state transitions
|
||||
connProbe.expectMsg(Tcp.ResumeReading)
|
||||
|
||||
val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!")
|
||||
fsm ! Tcp.Received(printCmd.encode)
|
||||
|
||||
trackerProbe.expectMsg(printCmd)
|
||||
}
|
||||
|
||||
it should "handle spill-over Tcp data correctly between state transition" in {
|
||||
val trackerProbe = TestProbe()
|
||||
val connProbe = TestProbe()
|
||||
|
||||
val worldSize = 4
|
||||
|
||||
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
|
||||
trackerProbe.ref, connProbe.ref))
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
|
||||
|
||||
// send mock magic number
|
||||
fsm ! Tcp.Received(magic)
|
||||
connProbe.expectMsg(Tcp.Write(magic))
|
||||
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
|
||||
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
|
||||
// ResumeReading should be seen once state transitions
|
||||
connProbe.expectMsg(Tcp.ResumeReading)
|
||||
|
||||
// send mock tracker command in fragments: the handler should be able to handle it.
|
||||
val bufCmd = ByteBuffer.allocate(26).order(ByteOrder.nativeOrder())
|
||||
bufCmd.putInt(0).putInt(worldSize).putInt(1).put(Array[Byte]('0'))
|
||||
.putInt(5).put("start".getBytes())
|
||||
// spilled-over data
|
||||
.putInt(0).flip()
|
||||
|
||||
// send data with 4 extra bytes corresponding to the next state.
|
||||
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
|
||||
|
||||
trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
|
||||
|
||||
val linkMap = new LinkMap(worldSize)
|
||||
val assignedRank = linkMap.assignRank(0)
|
||||
trackerProbe.reply(assignedRank)
|
||||
|
||||
connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
|
||||
assignedRank.toByteBuffer(worldSize)
|
||||
)))
|
||||
|
||||
// reading should be suspended upon transitioning to BuildingLinkMap
|
||||
connProbe.expectMsg(Tcp.SuspendReading)
|
||||
// state should transition with according state data changes.
|
||||
fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
|
||||
fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
|
||||
connProbe.expectMsg(Tcp.ResumeReading)
|
||||
|
||||
// the handler should be able to handle spill-over data, and stash it until state transition.
|
||||
trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user