[jvm-packages] Scala implementation of the Rabit tracker. (#1612)
* [jvm-packages] Scala implementation of the Rabit tracker. A Scala implementation of RabitTracker that is interface-interchangable with the Java implementation, ported from `tracker.py` in the [dmlc-core project](https://github.com/dmlc/dmlc-core). * [jvm-packages] Updated Akka dependency in pom.xml. * Refactored the RabitTracker directory structure. * Fixed premature stopping of connection handler. Added a new finite state "AwaitingPortNumber" to explicitly wait for the worker to send the port, and close the connection. Stopping the actor prematurely sends a TCP RST to the worker, causing the worker to crash on AssertionError. * Added interface IRabitTracker so that user can switch implementations. * Default timeout duration changes. * Dependency for Akka tests. * Removed the main function of RabitTracker. * A skeleton for testing Akka-based Rabit tracker. * waitFor() in RabitTracker no longer throws exceptions. * Completed unit test for the 'start' command of Rabit tracker. * Preliminary support for Rabit Allreduce via JNI (no prepare function support yet.) * Fixed the default timeout duration. * Use Java container to avoid serialization issues due to intermediate wrappers. * Added tests for Allreduce/model training using Scala Rabit tracker. * Added spill-over unit test for the Scala Rabit tracker. * Fixed a typo. * Overhaul of RabitTracker interface per code review. - Removed methods start() waitFor() (no arguments) from IRabitTracker. - The timeout in start(timeout) is now worker connection timeout, as tcp socket binding timeout is less intuitive. - Dropped time unit from start(...) and waitFor(...) methods; the default time unit is millisecond. - Moved random port number generation into the RabitTrackerHandler. - Moved all Rabit-related classes to package ml.dmlc.xgboost4j.scala.rabit. * More code refactoring and comments. * Unified timeout constants. Readable tracker status code. * Add comments to indicate that allReduce is for tests only. Removed all other variants. * Removed unused imports. * Simplified signatures of training methods. - Moved TrackerConf into parameter map. - Changed GeneralParams so that TrackerConf becomes a standalone parameter. - Updated test cases accordingly. * Changed monitoring strategies. * Reverted monitoring changes. * Update test case for Rabit AllReduce. * Mix in UncaughtExceptionHandler into IRabitTracker to prevent tracker from hanging due to exceptions thrown by workers. * More comprehensive test cases for exception handling and worker connection timeout. * Handle executor loss due to unknown cause: the newly spawned executor will attempt to connect to the tracker. Interrupt tracker in such case. * Per code-review, removed training timeout from TrackerConf. Timeout logic must be implemented explicitly and externally in the driver code. * Reverted scalastyle-config changes. * Visibility scope change. Interface tweaks. * Use match pattern to handle tracker_conf parameter. * Minor clarification in JNI code. * Clearer intent in match pattern to suppress warnings. * Removed Future from constructor. Block in start() and waitFor() instead. * Revert inadvertent comment changes. * Removed debugging information. * Updated test cases that are a bit finicky. * Added comments on the reasoning behind the unit tests for testing Rabit tracker robustness.
This commit is contained in:
parent
7078c41dad
commit
e7fbc8591f
@ -57,7 +57,7 @@ object CustomObjective {
|
|||||||
case e: XGBoostError =>
|
case e: XGBoostError =>
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
null
|
null
|
||||||
case _ =>
|
case _: Throwable =>
|
||||||
null
|
null
|
||||||
}
|
}
|
||||||
val grad = new Array[Float](nrow)
|
val grad = new Array[Float](nrow)
|
||||||
|
|||||||
@ -85,7 +85,7 @@ object XGBoost {
|
|||||||
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
|
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
|
||||||
XGBoostModel = {
|
XGBoostModel = {
|
||||||
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
|
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
|
||||||
if (tracker.start()) {
|
if (tracker.start(0L)) {
|
||||||
dtrain
|
dtrain
|
||||||
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
|
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
|
||||||
.reduce((x, y) => x).collect().head
|
.reduce((x, y) => x).collect().head
|
||||||
|
|||||||
@ -16,11 +16,10 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
import scala.collection.mutable.ListBuffer
|
import scala.collection.mutable.ListBuffer
|
||||||
|
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
|
||||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
|
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||||
import org.apache.commons.logging.LogFactory
|
import org.apache.commons.logging.LogFactory
|
||||||
import org.apache.hadoop.fs.{FSDataInputStream, Path}
|
import org.apache.hadoop.fs.{FSDataInputStream, Path}
|
||||||
@ -30,6 +29,25 @@ import org.apache.spark.rdd.RDD
|
|||||||
import org.apache.spark.sql.Dataset
|
import org.apache.spark.sql.Dataset
|
||||||
import org.apache.spark.{SparkContext, TaskContext}
|
import org.apache.spark.{SparkContext, TaskContext}
|
||||||
|
|
||||||
|
import scala.concurrent.duration.{Duration, MILLISECONDS}
|
||||||
|
|
||||||
|
object TrackerConf {
|
||||||
|
def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Rabit tracker configurations.
|
||||||
|
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
|
||||||
|
* Set timeout length to zero to disable timeout.
|
||||||
|
* Use a finite, non-zero timeout value to prevent tracker from
|
||||||
|
* hanging indefinitely (supported by "scala" implementation only.)
|
||||||
|
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
|
||||||
|
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
|
||||||
|
* in Scala without Python components, and with full support of timeouts.
|
||||||
|
* The Scala implementation is currently experimental, use at your own risk.
|
||||||
|
*/
|
||||||
|
case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
|
||||||
|
|
||||||
object XGBoost extends Serializable {
|
object XGBoost extends Serializable {
|
||||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||||
|
|
||||||
@ -80,7 +98,7 @@ object XGBoost extends Serializable {
|
|||||||
private[spark] def buildDistributedBoosters(
|
private[spark] def buildDistributedBoosters(
|
||||||
trainingSet: RDD[MLLabeledPoint],
|
trainingSet: RDD[MLLabeledPoint],
|
||||||
xgBoostConfMap: Map[String, Any],
|
xgBoostConfMap: Map[String, Any],
|
||||||
rabitEnv: mutable.Map[String, String],
|
rabitEnv: java.util.Map[String, String],
|
||||||
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
|
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
|
||||||
useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
|
useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
@ -92,7 +110,7 @@ object XGBoost extends Serializable {
|
|||||||
partitionedTrainingSet.mapPartitions {
|
partitionedTrainingSet.mapPartitions {
|
||||||
trainingSamples =>
|
trainingSamples =>
|
||||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||||
Rabit.init(rabitEnv.asJava)
|
Rabit.init(rabitEnv)
|
||||||
var booster: Booster = null
|
var booster: Booster = null
|
||||||
if (trainingSamples.hasNext) {
|
if (trainingSamples.hasNext) {
|
||||||
val cacheFileName: String = {
|
val cacheFileName: String = {
|
||||||
@ -211,9 +229,21 @@ object XGBoost extends Serializable {
|
|||||||
overridedParams
|
overridedParams
|
||||||
}
|
}
|
||||||
|
|
||||||
private def startTracker(nWorkers: Int): RabitTracker = {
|
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||||
val tracker = new RabitTracker(nWorkers)
|
val tracker: IRabitTracker = trackerConf.trackerImpl match {
|
||||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
case "scala" => new RabitTracker(nWorkers)
|
||||||
|
case "python" => new PyRabitTracker(nWorkers)
|
||||||
|
case _ => new PyRabitTracker(nWorkers)
|
||||||
|
}
|
||||||
|
|
||||||
|
val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) {
|
||||||
|
trackerConf.workerConnectionTimeout.toMillis
|
||||||
|
} else {
|
||||||
|
// 0 == Duration.Inf
|
||||||
|
0L
|
||||||
|
}
|
||||||
|
|
||||||
|
require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker")
|
||||||
tracker
|
tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,7 +257,7 @@ object XGBoost extends Serializable {
|
|||||||
* @param obj the user-defined objective function, null by default
|
* @param obj the user-defined objective function, null by default
|
||||||
* @param eval the user-defined evaluation function, null by default
|
* @param eval the user-defined evaluation function, null by default
|
||||||
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
|
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
|
||||||
* true, the user may save the RAM cost for running XGBoost within Spark
|
* true, the user may save the RAM cost for running XGBoost within Spark
|
||||||
* @param missing the value represented the missing value in the dataset
|
* @param missing the value represented the missing value in the dataset
|
||||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
|
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
|
||||||
* @return XGBoostModel when successful training
|
* @return XGBoostModel when successful training
|
||||||
@ -243,19 +273,26 @@ object XGBoost extends Serializable {
|
|||||||
" you have to specify the objective type as classification or regression with a" +
|
" you have to specify the objective type as classification or regression with a" +
|
||||||
" customized objective function")
|
" customized objective function")
|
||||||
}
|
}
|
||||||
val tracker = startTracker(nWorkers)
|
val trackerConf = params.get("tracker_conf") match {
|
||||||
|
case None => TrackerConf()
|
||||||
|
case Some(conf: TrackerConf) => conf
|
||||||
|
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
|
||||||
|
"instance of TrackerConf.")
|
||||||
|
}
|
||||||
|
val tracker = startTracker(nWorkers, trackerConf)
|
||||||
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
|
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
|
||||||
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
|
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
|
||||||
tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing)
|
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
|
||||||
val sparkJobThread = new Thread() {
|
val sparkJobThread = new Thread() {
|
||||||
override def run() {
|
override def run() {
|
||||||
// force the job
|
// force the job
|
||||||
boosters.foreachPartition(() => _)
|
boosters.foreachPartition(() => _)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
||||||
sparkJobThread.start()
|
sparkJobThread.start()
|
||||||
val isClsTask = isClassificationTask(params)
|
val isClsTask = isClassificationTask(params)
|
||||||
val trackerReturnVal = tracker.waitFor()
|
val trackerReturnVal = tracker.waitFor(0L)
|
||||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||||
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
|
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
|
||||||
isClsTask)
|
isClsTask)
|
||||||
|
|||||||
@ -16,9 +16,12 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark.params
|
package ml.dmlc.xgboost4j.scala.spark.params
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
||||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.param._
|
||||||
|
|
||||||
|
import scala.concurrent.duration.{Duration, NANOSECONDS}
|
||||||
|
|
||||||
trait GeneralParams extends Params {
|
trait GeneralParams extends Params {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -69,7 +72,38 @@ trait GeneralParams extends Params {
|
|||||||
*/
|
*/
|
||||||
val missing = new FloatParam(this, "missing", "the value treated as missing")
|
val missing = new FloatParam(this, "missing", "the value treated as missing")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Rabit tracker configurations. The parameter must be provided as an instance of the
|
||||||
|
* TrackerConf class, which has the following definition:
|
||||||
|
*
|
||||||
|
* case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration,
|
||||||
|
* trackerImpl: String)
|
||||||
|
*
|
||||||
|
* See below for detailed explanations.
|
||||||
|
*
|
||||||
|
* - trackerImpl: Select the implementation of Rabit tracker.
|
||||||
|
* default: "python"
|
||||||
|
*
|
||||||
|
* Choice between "python" or "scala". The former utilizes the Java wrapper of the
|
||||||
|
* Python Rabit tracker (in dmlc_core), and does not support timeout settings.
|
||||||
|
* The "scala" version removes Python components, and fully supports timeout settings.
|
||||||
|
*
|
||||||
|
* - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
|
||||||
|
* default: 0 millisecond (no timeout)
|
||||||
|
*
|
||||||
|
* The timeout value should take the time of data loading and pre-processing into account,
|
||||||
|
* due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
|
||||||
|
* perform data transformation before calling XGBoost.train(), so that this timeout truly
|
||||||
|
* reflects the connection delay. Set a reasonable timeout value to prevent model
|
||||||
|
* training/testing from hanging indefinitely, possible due to network issues.
|
||||||
|
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
|
||||||
|
* Ignored if the tracker implementation is "python".
|
||||||
|
*/
|
||||||
|
val trackerConf = new Param[TrackerConf](this, "tracker_conf", "Rabit tracker configurations")
|
||||||
|
|
||||||
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
|
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
|
||||||
useExternalMemory -> false, silent -> 0,
|
useExternalMemory -> false, silent -> 0,
|
||||||
customObj -> null, customEval -> null, missing -> Float.NaN)
|
customObj -> null, customEval -> null, missing -> Float.NaN,
|
||||||
|
trackerConf -> TrackerConf()
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -0,0 +1,169 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
|
||||||
|
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
||||||
|
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||||
|
import org.apache.spark.{SparkConf, SparkContext}
|
||||||
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
|
|
||||||
|
class RabitTrackerRobustnessSuite extends FunSuite with Utils {
|
||||||
|
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
|
||||||
|
/*
|
||||||
|
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
|
||||||
|
same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
|
||||||
|
by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
|
||||||
|
corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
|
||||||
|
tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
|
||||||
|
that should be avoided.
|
||||||
|
*/
|
||||||
|
val sparkConf = new SparkConf().setMaster("local[*]")
|
||||||
|
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
|
||||||
|
implicit val sparkContext = new SparkContext(sparkConf)
|
||||||
|
sparkContext.setLogLevel("ERROR")
|
||||||
|
|
||||||
|
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
|
||||||
|
|
||||||
|
val tracker = new PyRabitTracker(numWorkers)
|
||||||
|
tracker.start(0)
|
||||||
|
val trackerEnvs = tracker.getWorkerEnvs
|
||||||
|
|
||||||
|
val workerCount: Int = numWorkers
|
||||||
|
/*
|
||||||
|
Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
|
||||||
|
last created worker. A cascading event chain will be triggered once the RuntimeException is
|
||||||
|
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
|
||||||
|
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
|
||||||
|
|
||||||
|
The Java RabitTracker class reacts to exceptions by killing the spawned process running
|
||||||
|
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
|
||||||
|
it is killed, the resulted connection failure will trigger the Rabit worker to call
|
||||||
|
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
|
||||||
|
|
||||||
|
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
|
||||||
|
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
|
||||||
|
running in separate containers. However, as unit tests are run in Spark local mode, in which
|
||||||
|
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
|
||||||
|
ultimately kills the entire process, which also happens to host the Spark driver, causing
|
||||||
|
the entire Spark application to crash.
|
||||||
|
|
||||||
|
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
|
||||||
|
the exception is thrown at last, ideally after all worker connections have been established.
|
||||||
|
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
|
||||||
|
process to ensure that pending worker connections are handled.
|
||||||
|
*/
|
||||||
|
val dummyTasks = rdd.mapPartitions { iter =>
|
||||||
|
Rabit.init(trackerEnvs)
|
||||||
|
val index = iter.next()
|
||||||
|
Thread.sleep(100 + index * 10)
|
||||||
|
if (index == workerCount) {
|
||||||
|
// kill the worker by throwing an exception
|
||||||
|
throw new RuntimeException("Worker exception.")
|
||||||
|
}
|
||||||
|
Rabit.shutdown()
|
||||||
|
Iterator(index)
|
||||||
|
}.cache()
|
||||||
|
|
||||||
|
val sparkThread = new Thread() {
|
||||||
|
override def run(): Unit = {
|
||||||
|
// forces a Spark job.
|
||||||
|
dummyTasks.foreachPartition(() => _)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||||
|
sparkThread.start()
|
||||||
|
assert(tracker.waitFor(0) != 0)
|
||||||
|
sparkContext.stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
|
||||||
|
val sparkConf = new SparkConf().setMaster("local[*]")
|
||||||
|
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
|
||||||
|
implicit val sparkContext = new SparkContext(sparkConf)
|
||||||
|
sparkContext.setLogLevel("ERROR")
|
||||||
|
|
||||||
|
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
|
||||||
|
|
||||||
|
val tracker = new ScalaRabitTracker(numWorkers)
|
||||||
|
tracker.start(0)
|
||||||
|
val trackerEnvs = tracker.getWorkerEnvs
|
||||||
|
|
||||||
|
val workerCount: Int = numWorkers
|
||||||
|
val dummyTasks = rdd.mapPartitions { iter =>
|
||||||
|
Rabit.init(trackerEnvs)
|
||||||
|
val index = iter.next()
|
||||||
|
Thread.sleep(100 + index * 10)
|
||||||
|
if (index == workerCount) {
|
||||||
|
// kill the worker by throwing an exception
|
||||||
|
throw new RuntimeException("Worker exception.")
|
||||||
|
}
|
||||||
|
Rabit.shutdown()
|
||||||
|
Iterator(index)
|
||||||
|
}.cache()
|
||||||
|
|
||||||
|
val sparkThread = new Thread() {
|
||||||
|
override def run(): Unit = {
|
||||||
|
// forces a Spark job.
|
||||||
|
dummyTasks.foreachPartition(() => _)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||||
|
sparkThread.start()
|
||||||
|
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||||
|
sparkContext.stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test Scala RabitTracker's workerConnectionTimeout") {
|
||||||
|
val sparkConf = new SparkConf().setMaster("local[*]")
|
||||||
|
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
|
||||||
|
implicit val sparkContext = new SparkContext(sparkConf)
|
||||||
|
sparkContext.setLogLevel("ERROR")
|
||||||
|
|
||||||
|
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
|
||||||
|
|
||||||
|
val tracker = new ScalaRabitTracker(numWorkers)
|
||||||
|
tracker.start(500)
|
||||||
|
val trackerEnvs = tracker.getWorkerEnvs
|
||||||
|
|
||||||
|
val dummyTasks = rdd.mapPartitions { iter =>
|
||||||
|
val index = iter.next()
|
||||||
|
// simulate that the first worker cannot connect to tracker due to network issues.
|
||||||
|
if (index != 1) {
|
||||||
|
Rabit.init(trackerEnvs)
|
||||||
|
Thread.sleep(1000)
|
||||||
|
Rabit.shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
Iterator(index)
|
||||||
|
}.cache()
|
||||||
|
|
||||||
|
val sparkThread = new Thread() {
|
||||||
|
override def run(): Unit = {
|
||||||
|
// forces a Spark job.
|
||||||
|
dummyTasks.foreachPartition(() => _)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||||
|
sparkThread.start()
|
||||||
|
// should fail due to connection timeout
|
||||||
|
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||||
|
sparkContext.stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -17,18 +17,60 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import java.nio.file.Files
|
import java.nio.file.Files
|
||||||
|
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
|
||||||
|
|
||||||
import scala.collection.mutable.ListBuffer
|
import scala.collection.mutable.ListBuffer
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
import scala.concurrent.duration._
|
||||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
|
||||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
|
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||||
import org.apache.spark.SparkContext
|
import org.apache.spark.SparkContext
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.LabeledPoint
|
||||||
import org.apache.spark.ml.linalg.{Vector => SparkVector, Vectors}
|
import org.apache.spark.ml.linalg.{Vectors, Vector => SparkVector}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
|
||||||
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||||
|
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
|
||||||
|
val vectorLength = 100
|
||||||
|
val rdd = sc.parallelize(
|
||||||
|
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
|
||||||
|
|
||||||
|
val tracker = new RabitTracker(numWorkers)
|
||||||
|
tracker.start(0)
|
||||||
|
val trackerEnvs = tracker.getWorkerEnvs
|
||||||
|
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
|
||||||
|
|
||||||
|
val rawData = rdd.mapPartitions { iter =>
|
||||||
|
Iterator(iter.toArray)
|
||||||
|
}.collect()
|
||||||
|
|
||||||
|
val maxVec = (0 until vectorLength).toArray.map { j =>
|
||||||
|
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
|
||||||
|
}
|
||||||
|
|
||||||
|
val allReduceResults = rdd.mapPartitions { iter =>
|
||||||
|
Rabit.init(trackerEnvs)
|
||||||
|
val arr = iter.toArray
|
||||||
|
val results = Rabit.allReduce(arr, Rabit.OpType.MAX)
|
||||||
|
Rabit.shutdown()
|
||||||
|
Iterator(results)
|
||||||
|
}.cache()
|
||||||
|
|
||||||
|
val sparkThread = new Thread() {
|
||||||
|
override def run(): Unit = {
|
||||||
|
allReduceResults.foreachPartition(() => _)
|
||||||
|
val byPartitionResults = allReduceResults.collect()
|
||||||
|
assert(byPartitionResults(0).length == vectorLength)
|
||||||
|
collectedAllReduceResults.put(byPartitionResults(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sparkThread.start()
|
||||||
|
assert(tracker.waitFor(0L) == 0)
|
||||||
|
sparkThread.join()
|
||||||
|
|
||||||
|
assert(collectedAllReduceResults.poll().sameElements(maxVec))
|
||||||
|
}
|
||||||
|
|
||||||
test("build RDD containing boosters with the specified worker number") {
|
test("build RDD containing boosters with the specified worker number") {
|
||||||
val trainingRDD = buildTrainingRDD(sc)
|
val trainingRDD = buildTrainingRDD(sc)
|
||||||
@ -36,7 +78,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
|||||||
trainingRDD,
|
trainingRDD,
|
||||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic").toMap,
|
"objective" -> "binary:logistic").toMap,
|
||||||
new scala.collection.mutable.HashMap[String, String],
|
new java.util.HashMap[String, String](),
|
||||||
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true)
|
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true)
|
||||||
val boosterCount = boosterRDD.count()
|
val boosterCount = boosterRDD.count()
|
||||||
assert(boosterCount === 2)
|
assert(boosterCount === 2)
|
||||||
@ -59,6 +101,21 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
|||||||
cleanExternalCache("XGBoostSuite")
|
cleanExternalCache("XGBoostSuite")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("training with Scala-implemented Rabit tracker") {
|
||||||
|
val eval = new EvalError()
|
||||||
|
val trainingRDD = buildTrainingRDD(sc)
|
||||||
|
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||||
|
import DataUtils._
|
||||||
|
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||||
|
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic",
|
||||||
|
"tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
|
||||||
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||||
|
nWorkers = numWorkers, useExternalMemory = true)
|
||||||
|
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||||
|
testSetDMatrix) < 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
test("test with dense vectors containing missing value") {
|
test("test with dense vectors containing missing value") {
|
||||||
def buildDenseRDD(): RDD[LabeledPoint] = {
|
def buildDenseRDD(): RDD[LabeledPoint] = {
|
||||||
val nrow = 100
|
val nrow = 100
|
||||||
|
|||||||
@ -108,5 +108,17 @@
|
|||||||
<version>4.11</version>
|
<version>4.11</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.typesafe.akka</groupId>
|
||||||
|
<artifactId>akka-actor_${scala.binary.version}</artifactId>
|
||||||
|
<version>2.3.11</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.typesafe.akka</groupId>
|
||||||
|
<artifactId>akka-testkit_${scala.binary.version}</artifactId>
|
||||||
|
<version>2.3.11</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@ -0,0 +1,43 @@
|
|||||||
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface for Rabit tracker implementations with three public methods:
|
||||||
|
*
|
||||||
|
* - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given
|
||||||
|
* timeout value (in milliseconds.)
|
||||||
|
* - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
|
||||||
|
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
|
||||||
|
* milliseconds.
|
||||||
|
*
|
||||||
|
* Each implementation is expected to implement a callback function
|
||||||
|
*
|
||||||
|
* public void uncaughtException(Threat t, Throwable e) { ... }
|
||||||
|
*
|
||||||
|
* to interrupt waitFor() in order to prevent the tracker from hanging indefinitely.
|
||||||
|
*
|
||||||
|
* The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
|
||||||
|
* brokers connections between workers.
|
||||||
|
*/
|
||||||
|
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
||||||
|
public enum TrackerStatus {
|
||||||
|
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
|
||||||
|
|
||||||
|
private int statusCode;
|
||||||
|
|
||||||
|
TrackerStatus(int statusCode) {
|
||||||
|
this.statusCode = statusCode;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getStatusCode() {
|
||||||
|
return this.statusCode;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, String> getWorkerEnvs();
|
||||||
|
boolean start(long workerConnectionTimeout);
|
||||||
|
// taskExecutionTimeout has no effect in current version of XGBoost.
|
||||||
|
int waitFor(long taskExecutionTimeout);
|
||||||
|
}
|
||||||
@ -2,6 +2,9 @@ package ml.dmlc.xgboost4j.java;
|
|||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
@ -22,6 +25,42 @@ public class Rabit {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public enum OpType implements Serializable {
|
||||||
|
MAX(0), MIN(1), SUM(2), BITWISE_OR(3);
|
||||||
|
|
||||||
|
private int op;
|
||||||
|
|
||||||
|
public int getOperand() {
|
||||||
|
return this.op;
|
||||||
|
}
|
||||||
|
|
||||||
|
OpType(int op) {
|
||||||
|
this.op = op;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum DataType implements Serializable {
|
||||||
|
CHAR(0, 1), UCHAR(1, 1), INT(2, 4), UNIT(3, 4),
|
||||||
|
LONG(4, 8), ULONG(5, 8), FLOAT(6, 4), DOUBLE(7, 8),
|
||||||
|
LONGLONG(8, 8), ULONGLONG(9, 8);
|
||||||
|
|
||||||
|
private int enumOp;
|
||||||
|
private int size;
|
||||||
|
|
||||||
|
public int getEnumOp() {
|
||||||
|
return this.enumOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getSize() {
|
||||||
|
return this.size;
|
||||||
|
}
|
||||||
|
|
||||||
|
DataType(int enumOp, int size) {
|
||||||
|
this.enumOp = enumOp;
|
||||||
|
this.size = size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static void checkCall(int ret) throws XGBoostError {
|
private static void checkCall(int ret) throws XGBoostError {
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
|
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
|
||||||
@ -92,4 +131,30 @@ public class Rabit {
|
|||||||
checkCall(XGBoostJNI.RabitGetWorldSize(out));
|
checkCall(XGBoostJNI.RabitGetWorldSize(out));
|
||||||
return out[0];
|
return out[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* perform Allreduce on distributed float vectors using operator op.
|
||||||
|
* This implementation of allReduce does not support customized prepare function callback in the
|
||||||
|
* native code, as this function is meant for testing purposes only (to test the Rabit tracker.)
|
||||||
|
*
|
||||||
|
* @param elements local elements on distributed workers.
|
||||||
|
* @param op operator used for Allreduce.
|
||||||
|
* @return All-reduced float elements according to the given operator.
|
||||||
|
*/
|
||||||
|
public static float[] allReduce(float[] elements, OpType op) {
|
||||||
|
DataType dataType = DataType.FLOAT;
|
||||||
|
ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
|
||||||
|
.order(ByteOrder.nativeOrder());
|
||||||
|
|
||||||
|
for (float el : elements) {
|
||||||
|
buffer.putFloat(el);
|
||||||
|
}
|
||||||
|
buffer.flip();
|
||||||
|
|
||||||
|
XGBoostJNI.RabitAllreduce(buffer, elements.length, dataType.getEnumOp(), op.getOperand());
|
||||||
|
float[] results = new float[elements.length];
|
||||||
|
buffer.asFloatBuffer().get(results);
|
||||||
|
|
||||||
|
return results;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,15 +5,24 @@ package ml.dmlc.xgboost4j.java;
|
|||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Distributed RabitTracker, need to be started on driver code before running distributed jobs.
|
* Java implementation of the Rabit tracker to coordinate distributed workers.
|
||||||
|
* As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
|
||||||
|
* start() and waitFor() methods (i.e., the timeout is infinite.)
|
||||||
|
*
|
||||||
|
* For systems lacking Python environment, or for timeout functionality, consider using the Scala
|
||||||
|
* Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
|
||||||
|
* provides timeout support.
|
||||||
|
*
|
||||||
|
* The tracker must be started on driver node before running distributed jobs.
|
||||||
*/
|
*/
|
||||||
public class RabitTracker {
|
public class RabitTracker implements IRabitTracker {
|
||||||
// Maybe per tracker logger?
|
// Maybe per tracker logger?
|
||||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||||
// tracker python file.
|
// tracker python file.
|
||||||
@ -69,7 +78,6 @@ public class RabitTracker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public RabitTracker(int numWorkers)
|
public RabitTracker(int numWorkers)
|
||||||
throws XGBoostError {
|
throws XGBoostError {
|
||||||
if (numWorkers < 1) {
|
if (numWorkers < 1) {
|
||||||
@ -78,6 +86,17 @@ public class RabitTracker {
|
|||||||
this.numWorkers = numWorkers;
|
this.numWorkers = numWorkers;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void uncaughtException(Thread t, Throwable e) {
|
||||||
|
logger.error("Uncaught exception thrown by worker:", e);
|
||||||
|
try {
|
||||||
|
Thread.sleep(5000L);
|
||||||
|
} catch (InterruptedException ex) {
|
||||||
|
logger.error(ex);
|
||||||
|
} finally {
|
||||||
|
trackerProcess.get().destroy();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get environments that can be used to pass to worker.
|
* Get environments that can be used to pass to worker.
|
||||||
* @return The environment settings.
|
* @return The environment settings.
|
||||||
@ -126,7 +145,13 @@ public class RabitTracker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean start() {
|
public boolean start(long timeout) {
|
||||||
|
if (timeout > 0L) {
|
||||||
|
logger.warn("Python RabitTracker does not support timeout. " +
|
||||||
|
"The tracker will wait for all workers to connect indefinitely, unless " +
|
||||||
|
"it is interrupted manually. Use the Scala RabitTracker for timeout support.");
|
||||||
|
}
|
||||||
|
|
||||||
if (startTrackerProcess()) {
|
if (startTrackerProcess()) {
|
||||||
logger.debug("Tracker started, with env=" + envs.toString());
|
logger.debug("Tracker started, with env=" + envs.toString());
|
||||||
System.out.println("Tracker started, with env=" + envs.toString());
|
System.out.println("Tracker started, with env=" + envs.toString());
|
||||||
@ -142,7 +167,14 @@ public class RabitTracker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public int waitFor() {
|
public int waitFor(long timeout) {
|
||||||
|
if (timeout > 0L) {
|
||||||
|
logger.warn("Python RabitTracker does not support timeout. " +
|
||||||
|
"The tracker will wait for either all workers to finish tasks and send " +
|
||||||
|
"shutdown signal, or manual interruptions. " +
|
||||||
|
"Use the Scala RabitTracker for timeout support.");
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
trackerProcess.get().waitFor();
|
trackerProcess.get().waitFor();
|
||||||
int returnVal = trackerProcess.get().exitValue();
|
int returnVal = trackerProcess.get().exitValue();
|
||||||
@ -153,7 +185,7 @@ public class RabitTracker {
|
|||||||
// we should not get here as RabitTracker is accessed in the main thread
|
// we should not get here as RabitTracker is accessed in the main thread
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||||
return 1;
|
return TrackerStatus.INTERRUPTED.getStatusCode();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,6 +16,8 @@
|
|||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* xgboost JNI functions
|
* xgboost JNI functions
|
||||||
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
|
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
|
||||||
@ -97,4 +99,9 @@ class XGBoostJNI {
|
|||||||
public final static native int RabitGetRank(int[] out);
|
public final static native int RabitGetRank(int[] out);
|
||||||
public final static native int RabitGetWorldSize(int[] out);
|
public final static native int RabitGetWorldSize(int[] out);
|
||||||
public final static native int RabitVersionNumber(int[] out);
|
public final static native int RabitVersionNumber(int[] out);
|
||||||
|
|
||||||
|
// Perform Allreduce operation on data in sendrecvbuf.
|
||||||
|
// This JNI function does not support the callback function for data preparation yet.
|
||||||
|
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
|
||||||
|
int enum_dtype, int enum_op);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -0,0 +1,190 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.rabit
|
||||||
|
|
||||||
|
import java.net.{InetAddress, InetSocketAddress}
|
||||||
|
|
||||||
|
import akka.actor.ActorSystem
|
||||||
|
import akka.pattern.ask
|
||||||
|
import ml.dmlc.xgboost4j.java.IRabitTracker
|
||||||
|
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler
|
||||||
|
|
||||||
|
import scala.concurrent.duration._
|
||||||
|
import scala.concurrent.{Await, Future}
|
||||||
|
import scala.util.{Failure, Success, Try}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Scala implementation of the Rabit tracker interface without Python dependency.
|
||||||
|
* The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker
|
||||||
|
* (and thus any distributed tasks) to hang indefinitely due to network issues or worker node
|
||||||
|
* failures.
|
||||||
|
*
|
||||||
|
* Note that this implementation is currently experimental, and should be used at your own risk.
|
||||||
|
*
|
||||||
|
* Example usage:
|
||||||
|
* {{{
|
||||||
|
* import scala.concurrent.duration._
|
||||||
|
*
|
||||||
|
* val tracker = new RabitTracker(32)
|
||||||
|
* // allow up to 10 minutes for all workers to connect to the tracker.
|
||||||
|
* tracker.start(10 minutes)
|
||||||
|
*
|
||||||
|
* /* ...
|
||||||
|
* launching workers in parallel
|
||||||
|
* ...
|
||||||
|
* */
|
||||||
|
*
|
||||||
|
* // wait for worker execution up to 6 hours.
|
||||||
|
* // providing a finite timeout prevents a long-running task from hanging forever in
|
||||||
|
* // catastrophic events, like the loss of an executor during model training.
|
||||||
|
* tracker.waitFor(6 hours)
|
||||||
|
* }}}
|
||||||
|
*
|
||||||
|
* @param numWorkers Number of distributed workers from which the tracker expects connections.
|
||||||
|
* @param port The minimum port number that the tracker binds to.
|
||||||
|
* If port is omitted, or given as None, a random ephemeral port is chosen at runtime.
|
||||||
|
* @param maxPortTrials The maximum number of trials of socket binding, by sequentially
|
||||||
|
* increasing the port number.
|
||||||
|
*/
|
||||||
|
private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
|
||||||
|
maxPortTrials: Int = 1000)
|
||||||
|
extends IRabitTracker {
|
||||||
|
|
||||||
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
|
require(numWorkers >=1, "numWorkers must be greater than or equal to one (1).")
|
||||||
|
|
||||||
|
val system = ActorSystem.create("RabitTracker")
|
||||||
|
val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
|
||||||
|
implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds)
|
||||||
|
private[this] val tcpBindingTimeout: Duration = 1 minute
|
||||||
|
|
||||||
|
var workerEnvs: Map[String, String] = Map.empty
|
||||||
|
|
||||||
|
override def uncaughtException(t: Thread, e: Throwable): Unit = {
|
||||||
|
handler ? RabitTrackerHandler.InterruptTracker(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Start the Rabit tracker.
|
||||||
|
*
|
||||||
|
* @param timeout The timeout for awaiting connections from worker nodes.
|
||||||
|
* Note that when used in Spark applications, because all Spark transformations are
|
||||||
|
* lazily executed, the I/O time for loading RDDs/DataFrames from external sources
|
||||||
|
* (local dist, HDFS, S3 etc.) must be taken into account for the timeout value.
|
||||||
|
* If the timeout value is too small, the Rabit tracker will likely timeout before workers
|
||||||
|
* establishing connections to the tracker, due to the overhead of loading data.
|
||||||
|
* Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
|
||||||
|
* running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
|
||||||
|
* @return Boolean flag indicating if the Rabit tracker starts successfully.
|
||||||
|
*/
|
||||||
|
private def start(timeout: Duration): Boolean = {
|
||||||
|
handler ? RabitTrackerHandler.StartTracker(
|
||||||
|
new InetSocketAddress(InetAddress.getLocalHost, port.getOrElse(0)), maxPortTrials, timeout)
|
||||||
|
|
||||||
|
// block by waiting for the actor to bind to a port
|
||||||
|
Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
|
||||||
|
.asInstanceOf[Future[Map[String, String]]]) match {
|
||||||
|
case Success(futurePortBound) =>
|
||||||
|
// The success of the Future is contingent on binding to an InetSocketAddress.
|
||||||
|
val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess
|
||||||
|
if (isBound) {
|
||||||
|
workerEnvs = Await.result(futurePortBound, 0 nano)
|
||||||
|
}
|
||||||
|
isBound
|
||||||
|
case Failure(ex: Throwable) =>
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Start the Rabit tracker.
|
||||||
|
*
|
||||||
|
* @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
|
||||||
|
* connections. If a non-positive value is provided, the tracker
|
||||||
|
* waits for incoming worker connections indefinitely.
|
||||||
|
* @return Boolean flag indicating if the Rabit tracker starts successfully.
|
||||||
|
*/
|
||||||
|
def start(connectionTimeoutMillis: Long): Boolean = {
|
||||||
|
if (connectionTimeoutMillis <= 0) {
|
||||||
|
start(Duration.Inf)
|
||||||
|
} else {
|
||||||
|
start(Duration.fromNanos(connectionTimeoutMillis * 1e6))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a Map of necessary environment variables to initiate Rabit workers.
|
||||||
|
*
|
||||||
|
* @return HashMap containing tracker information.
|
||||||
|
*/
|
||||||
|
def getWorkerEnvs: java.util.Map[String, String] = {
|
||||||
|
new java.util.HashMap((workerEnvs ++ Map(
|
||||||
|
"DMLC_NUM_WORKER" -> numWorkers.toString,
|
||||||
|
"DMLC_NUM_SERVER" -> "0"
|
||||||
|
)).asJava)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
|
||||||
|
* This method blocks until timeout or task completion.
|
||||||
|
*
|
||||||
|
* @param atMost the maximum execution time for the workers. By default,
|
||||||
|
* the tracker waits for the workers indefinitely.
|
||||||
|
* @return 0 if the tasks complete successfully, and non-zero otherwise.
|
||||||
|
*/
|
||||||
|
private def waitFor(atMost: Duration): Int = {
|
||||||
|
// request the completion Future from the tracker actor
|
||||||
|
Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
|
||||||
|
.asInstanceOf[Future[Int]]) match {
|
||||||
|
case Success(futureCompleted) =>
|
||||||
|
// wait for all workers to complete synchronously.
|
||||||
|
val statusCode = Try(Await.result(futureCompleted, atMost)) match {
|
||||||
|
case Success(n) if n == numWorkers =>
|
||||||
|
IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
|
||||||
|
case Success(n) if n < numWorkers =>
|
||||||
|
IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
|
||||||
|
case Failure(e) =>
|
||||||
|
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
|
||||||
|
}
|
||||||
|
system.shutdown()
|
||||||
|
statusCode
|
||||||
|
case Failure(ex: Throwable) =>
|
||||||
|
if (!system.isTerminated) {
|
||||||
|
system.shutdown()
|
||||||
|
}
|
||||||
|
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
|
||||||
|
* This method blocks until timeout or task completion.
|
||||||
|
*
|
||||||
|
* @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a
|
||||||
|
* non-positive number is given, the tracker waits indefinitely.
|
||||||
|
* @return 0 if the tasks complete successfully, and non-zero otherwise
|
||||||
|
*/
|
||||||
|
def waitFor(atMostMillis: Long): Int = {
|
||||||
|
if (atMostMillis <= 0) {
|
||||||
|
waitFor(Duration.Inf)
|
||||||
|
} else {
|
||||||
|
waitFor(Duration.fromNanos(atMostMillis * 1e6))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@ -0,0 +1,362 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.rabit.handler
|
||||||
|
|
||||||
|
import java.net.InetSocketAddress
|
||||||
|
import java.util.UUID
|
||||||
|
|
||||||
|
import scala.concurrent.duration._
|
||||||
|
import scala.collection.mutable
|
||||||
|
import scala.concurrent.{Promise, TimeoutException}
|
||||||
|
import akka.io.{IO, Tcp}
|
||||||
|
import akka.actor._
|
||||||
|
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
|
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}
|
||||||
|
|
||||||
|
import scala.util.{Failure, Random, Success, Try}
|
||||||
|
|
||||||
|
/** The Akka actor for handling and coordinating Rabit worker connections.
|
||||||
|
* This is the main actor for handling socket connections, interacting with the synchronous
|
||||||
|
* tracker interface, and resolving tree/ring/parent dependencies between workers.
|
||||||
|
*
|
||||||
|
* @param numWorkers Number of workers to track.
|
||||||
|
*/
|
||||||
|
private[scala] class RabitTrackerHandler(numWorkers: Int)
|
||||||
|
extends Actor with ActorLogging {
|
||||||
|
|
||||||
|
import context.system
|
||||||
|
import RabitWorkerHandler._
|
||||||
|
import RabitTrackerHandler._
|
||||||
|
|
||||||
|
private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
|
||||||
|
private[this] val promisedShutdownWorkers = Promise[Int]()
|
||||||
|
private[this] val tcpManager = IO(Tcp)
|
||||||
|
|
||||||
|
// resolves worker connection dependency.
|
||||||
|
val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")
|
||||||
|
|
||||||
|
// workers that have sent "shutdown" signal
|
||||||
|
private[this] val shutdownWorkers = mutable.Set.empty[Int]
|
||||||
|
private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
|
||||||
|
private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
|
||||||
|
private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
|
||||||
|
private[this] var maxPortTrials = 0
|
||||||
|
private[this] var workerConnectionTimeout: Duration = Duration.Inf
|
||||||
|
private[this] var portTrials = 0
|
||||||
|
private[this] val startedWorkers = mutable.Set.empty[Int]
|
||||||
|
|
||||||
|
val linkMap = new LinkMap(numWorkers)
|
||||||
|
|
||||||
|
def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
|
||||||
|
rank match {
|
||||||
|
case r if r >= 0 => Some(r)
|
||||||
|
case _ =>
|
||||||
|
jobId match {
|
||||||
|
case "NULL" => None
|
||||||
|
case jid => jobToRankMap.get(jid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
|
||||||
|
* by the RabitWorkerHandler.
|
||||||
|
*
|
||||||
|
* @param event Generic Tcp.Event
|
||||||
|
*/
|
||||||
|
private def handleTcpEvents(event: Tcp.Event): Unit = event match {
|
||||||
|
case Tcp.Bound(local) =>
|
||||||
|
// expect all workers to connect within timeout
|
||||||
|
log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
|
||||||
|
log.info(s"Worker connection timeout is $workerConnectionTimeout.")
|
||||||
|
|
||||||
|
context.setReceiveTimeout(workerConnectionTimeout)
|
||||||
|
promisedWorkerEnvs.success(Map(
|
||||||
|
"DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
|
||||||
|
"DMLC_TRACKER_PORT" -> local.getPort.toString,
|
||||||
|
// not required because the world size will be communicated to the
|
||||||
|
// worker node after the rank is assigned.
|
||||||
|
"rabit_world_size" -> numWorkers.toString
|
||||||
|
))
|
||||||
|
|
||||||
|
case Tcp.CommandFailed(cmd: Tcp.Bind) =>
|
||||||
|
if (portTrials < maxPortTrials) {
|
||||||
|
portTrials += 1
|
||||||
|
tcpManager ! Tcp.Bind(self,
|
||||||
|
new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1),
|
||||||
|
backlog = 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
case Tcp.Connected(remote, local) =>
|
||||||
|
log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
|
||||||
|
// revoke timeout if all workers have connected.
|
||||||
|
val workerHandler = context.actorOf(RabitWorkerHandler.props(
|
||||||
|
remote.getAddress.getHostAddress, numWorkers, self, sender()
|
||||||
|
), s"ConnectionHandler-${UUID.randomUUID().toString}")
|
||||||
|
val connection = sender()
|
||||||
|
connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)
|
||||||
|
|
||||||
|
actorRefToHost.put(workerHandler, remote.getAddress.getHostName)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
|
||||||
|
* to interact with the tracker interface.
|
||||||
|
*
|
||||||
|
* @param trackerMsg control messages sent by RabitTracker class.
|
||||||
|
*/
|
||||||
|
private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
|
||||||
|
trackerMsg match {
|
||||||
|
|
||||||
|
case msg: StartTracker =>
|
||||||
|
maxPortTrials = msg.maxPortTrials
|
||||||
|
workerConnectionTimeout = msg.connectionTimeout
|
||||||
|
|
||||||
|
// if the port number is missing, try binding to a random ephemeral port.
|
||||||
|
if (msg.addr.getPort == 0) {
|
||||||
|
tcpManager ! Tcp.Bind(self,
|
||||||
|
new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
|
||||||
|
backlog = 256)
|
||||||
|
} else {
|
||||||
|
tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
|
||||||
|
}
|
||||||
|
sender() ! true
|
||||||
|
|
||||||
|
case RequestBoundFuture =>
|
||||||
|
sender() ! promisedWorkerEnvs.future
|
||||||
|
|
||||||
|
case RequestCompletionFuture =>
|
||||||
|
sender() ! promisedShutdownWorkers.future
|
||||||
|
|
||||||
|
case InterruptTracker(e) =>
|
||||||
|
log.error(e, "Uncaught exception thrown by worker.")
|
||||||
|
// make sure that waitFor() does not hang indefinitely.
|
||||||
|
promisedShutdownWorkers.failure(e)
|
||||||
|
context.stop(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles messages sent by child actors representing connecting Rabit workers, by brokering
|
||||||
|
* messages to the dependency resolver, and processing worker commands.
|
||||||
|
*
|
||||||
|
* @param workerMsg Message sent by RabitWorkerHandler actors.
|
||||||
|
*/
|
||||||
|
private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
|
||||||
|
case req @ RequestAwaitConnWorkers(_, _) =>
|
||||||
|
// since the requester may request to connect to other workers
|
||||||
|
// that have not fully set up, delegate this request to the
|
||||||
|
// dependency resolver which handles the dependencies properly.
|
||||||
|
resolver forward req
|
||||||
|
|
||||||
|
// ---- Rabit worker commands: start/recover/shutdown/print ----
|
||||||
|
case WorkerTrackerPrint(_, _, _, msg) =>
|
||||||
|
log.info(msg.trim)
|
||||||
|
|
||||||
|
case WorkerShutdown(rank, _, _) =>
|
||||||
|
assert(rank >= 0, "Invalid rank.")
|
||||||
|
assert(!shutdownWorkers.contains(rank))
|
||||||
|
shutdownWorkers.add(rank)
|
||||||
|
|
||||||
|
log.info(s"Received shutdown signal from $rank")
|
||||||
|
|
||||||
|
if (shutdownWorkers.size == numWorkers) {
|
||||||
|
promisedShutdownWorkers.success(shutdownWorkers.size)
|
||||||
|
context.stop(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
case WorkerRecover(prevRank, worldSize, jobId) =>
|
||||||
|
assert(prevRank >= 0)
|
||||||
|
sender() ! linkMap.assignRank(prevRank)
|
||||||
|
|
||||||
|
case WorkerStart(rank, worldSize, jobId) =>
|
||||||
|
assert(worldSize == numWorkers || worldSize == -1,
|
||||||
|
s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
|
||||||
|
)
|
||||||
|
|
||||||
|
Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
|
||||||
|
case Success(wkRank) =>
|
||||||
|
if (jobId != "NULL") {
|
||||||
|
jobToRankMap.put(jobId, wkRank)
|
||||||
|
}
|
||||||
|
|
||||||
|
val assignedRank = linkMap.assignRank(wkRank)
|
||||||
|
sender() ! assignedRank
|
||||||
|
resolver ! assignedRank
|
||||||
|
|
||||||
|
log.info("Received start signal from " +
|
||||||
|
s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]")
|
||||||
|
|
||||||
|
case Failure(ex: IndexOutOfBoundsException) =>
|
||||||
|
// More than worldSize workers have connected, likely due to executor loss.
|
||||||
|
// Since Rabit currently does not support crash recovery (because the Allreduce results
|
||||||
|
// are not cached by the tracker, and because existing workers cannot reestablish
|
||||||
|
// connections to newly spawned executor/worker), the most reasonble action here is to
|
||||||
|
// interrupt the tracker immediate with failure state.
|
||||||
|
log.error("Received invalid start signal from " +
|
||||||
|
s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started."
|
||||||
|
)
|
||||||
|
promisedShutdownWorkers.failure(new XGBoostError("Invalid start signal" +
|
||||||
|
" received from worker, likely due to executor loss."))
|
||||||
|
|
||||||
|
case Failure(ex) =>
|
||||||
|
log.error(ex, "Unexpected error")
|
||||||
|
promisedShutdownWorkers.failure(ex)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ---- Dependency resolving related messages ----
|
||||||
|
case msg @ WorkerStarted(host, rank, awaitingAcceptance) =>
|
||||||
|
log.info(s"Worker $host (rank: $rank) has started.")
|
||||||
|
resolver forward msg
|
||||||
|
|
||||||
|
startedWorkers.add(rank)
|
||||||
|
if (startedWorkers.size == numWorkers) {
|
||||||
|
log.info("All workers have started.")
|
||||||
|
}
|
||||||
|
|
||||||
|
case req @ DropFromWaitingList(_) =>
|
||||||
|
// all peer workers in dependency link map have connected;
|
||||||
|
// forward message to resolver to update dependencies.
|
||||||
|
resolver forward req
|
||||||
|
|
||||||
|
case _ =>
|
||||||
|
}
|
||||||
|
|
||||||
|
def receive: Actor.Receive = {
|
||||||
|
case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
|
||||||
|
case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
|
||||||
|
case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)
|
||||||
|
|
||||||
|
case akka.actor.ReceiveTimeout =>
|
||||||
|
if (startedWorkers.size < numWorkers) {
|
||||||
|
promisedShutdownWorkers.failure(
|
||||||
|
new TimeoutException("Timed out waiting for workers to connect: " +
|
||||||
|
s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
|
||||||
|
)
|
||||||
|
context.stop(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
context.setReceiveTimeout(Duration.Undefined)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Resolve the dependency between nodes as they connect to the tracker.
|
||||||
|
* The dependency is enforced that a worker of rank K depends on its neighbors (from the treeMap
|
||||||
|
* and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order of
|
||||||
|
* connections by workers, this dependency constraint assumes that a worker node connects first
|
||||||
|
* is likely to finish setup first.
|
||||||
|
*/
|
||||||
|
private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
|
||||||
|
import RabitWorkerHandler._
|
||||||
|
|
||||||
|
context.watch(handler)
|
||||||
|
|
||||||
|
case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])
|
||||||
|
|
||||||
|
// worker nodes that have connected, but have not send WorkerStarted message.
|
||||||
|
private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
|
||||||
|
private val startedWorkers = mutable.Set.empty[Int]
|
||||||
|
// worker nodes that have started, and await for connections.
|
||||||
|
private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
|
||||||
|
private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]
|
||||||
|
|
||||||
|
def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
|
||||||
|
val connSet = awaitConnWorkers.toMap
|
||||||
|
.filterKeys(k => linkSet.contains(k))
|
||||||
|
AwaitingConnections(connSet, linkSet.size - connSet.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
def receive: Actor.Receive = {
|
||||||
|
// a copy of the AssignedRank message that is also sent to the worker
|
||||||
|
case AssignedRank(rank, tree_neighbors, ring, parent) =>
|
||||||
|
// the workers that the worker of given `rank` depends on:
|
||||||
|
// worker of rank K only depends on workers with rank smaller than K.
|
||||||
|
val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
|
||||||
|
.filter{ r => r != -1 && r < rank}
|
||||||
|
|
||||||
|
log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
|
||||||
|
dependencyMap.put(rank, dependentWorkers)
|
||||||
|
|
||||||
|
case RequestAwaitConnWorkers(rank, toConnectSet) =>
|
||||||
|
val promise = Promise[AwaitingConnections]()
|
||||||
|
|
||||||
|
assert(dependencyMap.contains(rank))
|
||||||
|
|
||||||
|
val updatedDependency = dependencyMap(rank) diff startedWorkers
|
||||||
|
if (updatedDependency.isEmpty) {
|
||||||
|
// all dependencies are satisfied
|
||||||
|
log.debug(s"Rank $rank has all dependencies satisfied.")
|
||||||
|
promise.success(awaitingWorkers(toConnectSet))
|
||||||
|
} else {
|
||||||
|
log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
|
||||||
|
// promise is pending fulfillment due to unresolved dependency
|
||||||
|
pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
|
||||||
|
}
|
||||||
|
|
||||||
|
sender() ! promise.future
|
||||||
|
|
||||||
|
case WorkerStarted(_, started, awaitingAcceptance) =>
|
||||||
|
startedWorkers.add(started)
|
||||||
|
if (awaitingAcceptance > 0) {
|
||||||
|
awaitConnWorkers.put(started, sender())
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove the started rank from all dependencies.
|
||||||
|
dependencyMap.remove(started)
|
||||||
|
dependencyMap.foreach { case (r, dset) =>
|
||||||
|
val updatedDependency = dset diff startedWorkers
|
||||||
|
// fulfill the future if all dependencies are met (started.)
|
||||||
|
if (updatedDependency.isEmpty) {
|
||||||
|
log.debug(s"Rank $r has all dependencies satisfied.")
|
||||||
|
pendingFulfillment.remove(r).map{
|
||||||
|
case Fulfillment(toConnectSet, promise) =>
|
||||||
|
promise.success(awaitingWorkers(toConnectSet))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencyMap.update(r, updatedDependency)
|
||||||
|
}
|
||||||
|
|
||||||
|
case DropFromWaitingList(rank) =>
|
||||||
|
assert(awaitConnWorkers.remove(rank).isDefined)
|
||||||
|
|
||||||
|
case Terminated(ref) =>
|
||||||
|
if (ref.equals(handler)) {
|
||||||
|
context.stop(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private[scala] object RabitTrackerHandler {
|
||||||
|
// Messages sent by RabitTracker to this RabitTrackerHandler actor
|
||||||
|
trait TrackerControlMessage
|
||||||
|
case object RequestCompletionFuture extends TrackerControlMessage
|
||||||
|
case object RequestBoundFuture extends TrackerControlMessage
|
||||||
|
// Start the Rabit tracker at given socket address awaiting worker connections.
|
||||||
|
// All workers must connect to the tracker before connectionTimeout, otherwise the tracker will
|
||||||
|
// shut down due to timeout.
|
||||||
|
case class StartTracker(addr: InetSocketAddress,
|
||||||
|
maxPortTrials: Int,
|
||||||
|
connectionTimeout: Duration) extends TrackerControlMessage
|
||||||
|
// To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
|
||||||
|
// driver for the distributed training.
|
||||||
|
case class InterruptTracker(e: Throwable) extends TrackerControlMessage
|
||||||
|
|
||||||
|
def props(numWorkers: Int): Props =
|
||||||
|
Props(new RabitTrackerHandler(numWorkers))
|
||||||
|
}
|
||||||
@ -0,0 +1,456 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.rabit.handler
|
||||||
|
|
||||||
|
import java.nio.{ByteBuffer, ByteOrder}
|
||||||
|
|
||||||
|
import akka.io.Tcp
|
||||||
|
import akka.actor._
|
||||||
|
import akka.util.ByteString
|
||||||
|
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}
|
||||||
|
|
||||||
|
import scala.concurrent.{Await, Future}
|
||||||
|
import scala.concurrent.duration._
|
||||||
|
import scala.util.Try
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Actor to handle socket communication from worker node.
|
||||||
|
* To handle fragmentation in received data, this class acts like a FSM
|
||||||
|
* (finite-state machine) to keep track of the internal states.
|
||||||
|
*
|
||||||
|
* @param host IP address of the remote worker
|
||||||
|
* @param worldSize number of total workers
|
||||||
|
* @param tracker the RabitTrackerHandler actor reference
|
||||||
|
*/
|
||||||
|
private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
|
||||||
|
connection: ActorRef)
|
||||||
|
extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
|
||||||
|
with ActorLogging with Stash {
|
||||||
|
|
||||||
|
import RabitWorkerHandler._
|
||||||
|
import RabitTrackerHelpers._
|
||||||
|
|
||||||
|
private[this] var rank: Int = 0
|
||||||
|
private[this] var port: Int = 0
|
||||||
|
|
||||||
|
// indicate if the connection is transient (like "print" or "shutdown")
|
||||||
|
private[this] var transient: Boolean = false
|
||||||
|
private[this] var peerClosed: Boolean = false
|
||||||
|
|
||||||
|
// number of workers pending acceptance of current worker
|
||||||
|
private[this] var awaitingAcceptance: Int = 0
|
||||||
|
private[this] var neighboringWorkers = Set.empty[Int]
|
||||||
|
|
||||||
|
// TODO: use a single memory allocation to host all buffers,
|
||||||
|
// including the transient ones for writing.
|
||||||
|
private[this] val readBuffer = ByteBuffer.allocate(4096)
|
||||||
|
.order(ByteOrder.nativeOrder())
|
||||||
|
// in case the received message is longer than needed,
|
||||||
|
// stash the spilled over part in this buffer, and send
|
||||||
|
// to self when transition occurs.
|
||||||
|
private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
|
||||||
|
.order(ByteOrder.nativeOrder())
|
||||||
|
// when setup is complete, need to notify peer handlers
|
||||||
|
// to reduce the awaiting-connection counter.
|
||||||
|
private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None
|
||||||
|
|
||||||
|
private def resetBuffers(): Unit = {
|
||||||
|
readBuffer.clear()
|
||||||
|
if (spillOverBuffer.position() > 0) {
|
||||||
|
spillOverBuffer.flip()
|
||||||
|
self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
|
||||||
|
spillOverBuffer.clear()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private def stashSpillOver(buf: ByteBuffer): Unit = {
|
||||||
|
if (buf.remaining() > 0) spillOverBuffer.put(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
def getNeighboringWorkers: Set[Int] = neighboringWorkers
|
||||||
|
|
||||||
|
def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
|
||||||
|
val rank = buffer.getInt()
|
||||||
|
val worldSize = buffer.getInt()
|
||||||
|
val jobId = buffer.getString
|
||||||
|
|
||||||
|
val command = buffer.getString
|
||||||
|
command match {
|
||||||
|
case "start" => WorkerStart(rank, worldSize, jobId)
|
||||||
|
case "shutdown" =>
|
||||||
|
transient = true
|
||||||
|
WorkerShutdown(rank, worldSize, jobId)
|
||||||
|
case "recover" =>
|
||||||
|
require(rank >= 0, "Invalid rank for recovering worker.")
|
||||||
|
WorkerRecover(rank, worldSize, jobId)
|
||||||
|
case "print" =>
|
||||||
|
transient = true
|
||||||
|
WorkerTrackerPrint(rank, worldSize, jobId, buffer.getString)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
startWith(AwaitingHandshake, DataStruct())
|
||||||
|
|
||||||
|
when(AwaitingHandshake) {
|
||||||
|
case Event(Tcp.Received(magic), _) =>
|
||||||
|
assert(magic.length == 4)
|
||||||
|
val purportedMagic = magic.asNativeOrderByteBuffer.getInt
|
||||||
|
assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")
|
||||||
|
|
||||||
|
// echo back the magic number
|
||||||
|
connection ! Tcp.Write(magic)
|
||||||
|
goto(AwaitingCommand) using StructTrackerCommand
|
||||||
|
}
|
||||||
|
|
||||||
|
when(AwaitingCommand) {
|
||||||
|
case Event(Tcp.Received(bytes), validator) =>
|
||||||
|
bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
|
||||||
|
if (validator.verify(readBuffer)) {
|
||||||
|
readBuffer.flip()
|
||||||
|
tracker ! decodeCommand(readBuffer)
|
||||||
|
stashSpillOver(readBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
stay
|
||||||
|
// when rank for a worker is assigned, send encoded rank information
|
||||||
|
// back to worker over Tcp socket.
|
||||||
|
case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
|
||||||
|
log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")
|
||||||
|
|
||||||
|
rank = assignedRank
|
||||||
|
// ranks from the ring
|
||||||
|
val ringRanks = List(
|
||||||
|
// ringPrev
|
||||||
|
if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
|
||||||
|
// ringNext
|
||||||
|
if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
|
||||||
|
)
|
||||||
|
|
||||||
|
// update the set of all linked workers to current worker.
|
||||||
|
neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet
|
||||||
|
|
||||||
|
connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
|
||||||
|
// to prevent reading before state transition
|
||||||
|
connection ! Tcp.SuspendReading
|
||||||
|
goto(BuildingLinkMap) using StructNodes
|
||||||
|
}
|
||||||
|
|
||||||
|
when(BuildingLinkMap) {
|
||||||
|
case Event(Tcp.Received(bytes), validator) =>
|
||||||
|
bytes.asByteBuffers.foreach { buf =>
|
||||||
|
readBuffer.put(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (validator.verify(readBuffer)) {
|
||||||
|
readBuffer.flip()
|
||||||
|
// for a freshly started worker, numConnected should be 0.
|
||||||
|
val numConnected = readBuffer.getInt()
|
||||||
|
val toConnectSet = neighboringWorkers.diff(
|
||||||
|
(0 until numConnected).map { index => readBuffer.getInt() }.toSet)
|
||||||
|
|
||||||
|
// check which workers are currently awaiting connections
|
||||||
|
tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
|
||||||
|
}
|
||||||
|
stay
|
||||||
|
|
||||||
|
// got a Future from the tracker (resolver) about workers that are
|
||||||
|
// currently awaiting connections (particularly from this node.)
|
||||||
|
case Event(future: Future[_], _) =>
|
||||||
|
// blocks execution until all dependencies for current worker is resolved.
|
||||||
|
Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
|
||||||
|
// numNotReachable is the number of workers that currently
|
||||||
|
// cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
|
||||||
|
// connections from those currently non-reachable nodes in the future.
|
||||||
|
case AwaitingConnections(waitConnNodes, numNotReachable) =>
|
||||||
|
log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")
|
||||||
|
val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
|
||||||
|
buf.putInt(waitConnNodes.size).putInt(numNotReachable)
|
||||||
|
buf.flip()
|
||||||
|
|
||||||
|
// cache this message until the final state (SetupComplete)
|
||||||
|
pendingAcknowledgement = Some(AcknowledgeAcceptance(
|
||||||
|
waitConnNodes, numNotReachable))
|
||||||
|
|
||||||
|
connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
|
||||||
|
if (waitConnNodes.isEmpty) {
|
||||||
|
connection ! Tcp.SuspendReading
|
||||||
|
goto(AwaitingErrorCount)
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
waitConnNodes.foreach { case (peerRank, peerRef) =>
|
||||||
|
peerRef ! RequestWorkerHostPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// a countdown for DivulgedHostPort messages.
|
||||||
|
stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
|
||||||
|
val hostBytes = peerHost.getBytes()
|
||||||
|
val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
|
||||||
|
.order(ByteOrder.nativeOrder())
|
||||||
|
buffer.putInt(peerHost.length).put(hostBytes)
|
||||||
|
.putInt(peerPort).putInt(peerRank)
|
||||||
|
|
||||||
|
buffer.flip()
|
||||||
|
connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))
|
||||||
|
|
||||||
|
if (data.counter == 0) {
|
||||||
|
// to prevent reading before state transition
|
||||||
|
connection ! Tcp.SuspendReading
|
||||||
|
goto(AwaitingErrorCount)
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
stay using data.decrement()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
when(AwaitingErrorCount) {
|
||||||
|
case Event(Tcp.Received(numErrors), _) =>
|
||||||
|
val buf = numErrors.asNativeOrderByteBuffer
|
||||||
|
|
||||||
|
buf.getInt match {
|
||||||
|
case 0 =>
|
||||||
|
stashSpillOver(buf)
|
||||||
|
goto(AwaitingPortNumber)
|
||||||
|
case _ =>
|
||||||
|
stashSpillOver(buf)
|
||||||
|
goto(BuildingLinkMap) using StructNodes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
when(AwaitingPortNumber) {
|
||||||
|
case Event(Tcp.Received(assignedPort), _) =>
|
||||||
|
assert(assignedPort.length == 4)
|
||||||
|
port = assignedPort.asNativeOrderByteBuffer.getInt
|
||||||
|
log.debug(s"Rank $rank listening @ $host:$port")
|
||||||
|
// wait until the worker closes connection.
|
||||||
|
if (peerClosed) goto(SetupComplete) else stay
|
||||||
|
|
||||||
|
case Event(Tcp.PeerClosed, _) =>
|
||||||
|
peerClosed = true
|
||||||
|
if (port == 0) stay else goto(SetupComplete)
|
||||||
|
}
|
||||||
|
|
||||||
|
when(SetupComplete) {
|
||||||
|
case Event(ReduceWaitCount(count: Int), _) =>
|
||||||
|
awaitingAcceptance -= count
|
||||||
|
// check peerClosed to avoid prematurely stopping this actor (which sends RST to worker)
|
||||||
|
if (awaitingAcceptance == 0 && peerClosed) {
|
||||||
|
tracker ! DropFromWaitingList(rank)
|
||||||
|
// no longer needed.
|
||||||
|
context.stop(self)
|
||||||
|
}
|
||||||
|
stay
|
||||||
|
|
||||||
|
case Event(AcknowledgeAcceptance(peers, numBad), _) =>
|
||||||
|
awaitingAcceptance = numBad
|
||||||
|
tracker ! WorkerStarted(host, rank, awaitingAcceptance)
|
||||||
|
peers.values.foreach { peer =>
|
||||||
|
peer ! ReduceWaitCount(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill
|
||||||
|
|
||||||
|
stay
|
||||||
|
|
||||||
|
// can only divulge the complete host and port information
|
||||||
|
// when this worker is declared fully connected (otherwise
|
||||||
|
// port information is still missing.)
|
||||||
|
case Event(RequestWorkerHostPort, _) =>
|
||||||
|
sender() ! DivulgedWorkerHostPort(rank, host, port)
|
||||||
|
stay
|
||||||
|
}
|
||||||
|
|
||||||
|
onTransition {
|
||||||
|
// reset buffer when state transitions as data becomes stale
|
||||||
|
case _ -> SetupComplete =>
|
||||||
|
connection ! Tcp.ResumeReading
|
||||||
|
resetBuffers()
|
||||||
|
if (pendingAcknowledgement.isDefined) {
|
||||||
|
self ! pendingAcknowledgement.get
|
||||||
|
}
|
||||||
|
case _ =>
|
||||||
|
connection ! Tcp.ResumeReading
|
||||||
|
resetBuffers()
|
||||||
|
}
|
||||||
|
|
||||||
|
// default message handler
|
||||||
|
whenUnhandled {
|
||||||
|
case Event(Tcp.PeerClosed, _) =>
|
||||||
|
peerClosed = true
|
||||||
|
if (transient) context.stop(self)
|
||||||
|
stay
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private[scala] object RabitWorkerHandler {
|
||||||
|
val MAGIC_NUMBER = 0xff99
|
||||||
|
|
||||||
|
// Finite states of this actor, which acts like a FSM.
|
||||||
|
// The following states are defined in order as the FSM progresses.
|
||||||
|
sealed trait State
|
||||||
|
|
||||||
|
// [1] Initial state, awaiting worker to send magic number per protocol.
|
||||||
|
case object AwaitingHandshake extends State
|
||||||
|
// [2] Awaiting worker to send command (start/print/recover/shutdown etc.)
|
||||||
|
case object AwaitingCommand extends State
|
||||||
|
// [3] Brokers connections between workers per ring/tree/parent link map.
|
||||||
|
case object BuildingLinkMap extends State
|
||||||
|
// [4] A transient state in which the worker reports the number of errors in establishing
|
||||||
|
// connections to other peer workers. If no errors, transition to next state.
|
||||||
|
case object AwaitingErrorCount extends State
|
||||||
|
// [5] Awaiting the worker to report its port number for accepting connections from peer workers.
|
||||||
|
// This port number information is later forwarded to linked workers.
|
||||||
|
case object AwaitingPortNumber extends State
|
||||||
|
// [6] Final state after completing the setup with the connecting worker. At this stage, the
|
||||||
|
// worker will have closed the Tcp connection. The actor remains alive to handle messages from
|
||||||
|
// peer actors representing workers with pending setups.
|
||||||
|
case object SetupComplete extends State
|
||||||
|
|
||||||
|
sealed trait DataField
|
||||||
|
case object IntField extends DataField
|
||||||
|
// an integer preceding the actual string
|
||||||
|
case object StringField extends DataField
|
||||||
|
case object IntSeqField extends DataField
|
||||||
|
|
||||||
|
object DataStruct {
|
||||||
|
def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Internal data pertaining to individual state, used to verify the validity of packets sent by
|
||||||
|
// workers.
|
||||||
|
case class DataStruct(fields: Seq[DataField], counter: Int) {
|
||||||
|
/**
|
||||||
|
* Validate whether the provided buffer is complete (i.e., contains
|
||||||
|
* all data fields specified for this DataStruct.)
|
||||||
|
*
|
||||||
|
* @param buf a byte buffer containing received data.
|
||||||
|
*/
|
||||||
|
def verify(buf: ByteBuffer): Boolean = {
|
||||||
|
if (fields.isEmpty) return true
|
||||||
|
|
||||||
|
val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
|
||||||
|
dupBuf.flip()
|
||||||
|
|
||||||
|
Try(fields.foldLeft(true) {
|
||||||
|
case (complete, field) =>
|
||||||
|
val remBytes = dupBuf.remaining()
|
||||||
|
complete && (remBytes > 0) && (remBytes >= (field match {
|
||||||
|
case IntField =>
|
||||||
|
dupBuf.position(dupBuf.position() + 4)
|
||||||
|
4
|
||||||
|
case StringField =>
|
||||||
|
val strLen = dupBuf.getInt
|
||||||
|
dupBuf.position(dupBuf.position() + strLen)
|
||||||
|
4 + strLen
|
||||||
|
case IntSeqField =>
|
||||||
|
val seqLen = dupBuf.getInt
|
||||||
|
dupBuf.position(dupBuf.position() + seqLen * 4)
|
||||||
|
4 + seqLen * 4
|
||||||
|
}))
|
||||||
|
}).getOrElse(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
def increment(): DataStruct = DataStruct(fields, counter + 1)
|
||||||
|
def decrement(): DataStruct = DataStruct(fields, counter - 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
val StructNodes = DataStruct(List(IntSeqField), 0)
|
||||||
|
val StructTrackerCommand = DataStruct(List(
|
||||||
|
IntField, IntField, StringField, StringField
|
||||||
|
), 0)
|
||||||
|
|
||||||
|
// ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ----
|
||||||
|
|
||||||
|
// RabitWorkerHandler --> RabitTrackerHandler
|
||||||
|
sealed trait RabitWorkerRequest
|
||||||
|
// RabitWorkerHandler <-- RabitTrackerHandler
|
||||||
|
sealed trait RabitWorkerResponse
|
||||||
|
|
||||||
|
// Representations of decoded worker commands.
|
||||||
|
abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
|
||||||
|
def rank: Int
|
||||||
|
def worldSize: Int
|
||||||
|
def jobId: String
|
||||||
|
|
||||||
|
def encode: ByteString = {
|
||||||
|
val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length)
|
||||||
|
.order(ByteOrder.nativeOrder())
|
||||||
|
|
||||||
|
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
|
||||||
|
.putInt(command.length).put(command.getBytes()).flip()
|
||||||
|
|
||||||
|
ByteString.fromByteBuffer(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
|
||||||
|
extends TrackerCommand("start")
|
||||||
|
case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
|
||||||
|
extends TrackerCommand("shutdown")
|
||||||
|
case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
|
||||||
|
extends TrackerCommand("recover")
|
||||||
|
case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
|
||||||
|
extends TrackerCommand("print") {
|
||||||
|
|
||||||
|
override def encode: ByteString = {
|
||||||
|
val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length)
|
||||||
|
.order(ByteOrder.nativeOrder())
|
||||||
|
|
||||||
|
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
|
||||||
|
.putInt(command.length).put(command.getBytes())
|
||||||
|
.putInt(msg.length).put(msg.getBytes()).flip()
|
||||||
|
|
||||||
|
ByteString.fromByteBuffer(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request to remove the worker of given rank from the list of workers awaiting peer connections.
|
||||||
|
case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
|
||||||
|
// Notify the tracker that the worker of given rank has finished setup and started.
|
||||||
|
case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
|
||||||
|
extends RabitWorkerRequest
|
||||||
|
// Request the set of workers to connect to, according to the LinkMap structure.
|
||||||
|
case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
|
||||||
|
extends RabitWorkerRequest
|
||||||
|
|
||||||
|
// Request, from the tracker, the set of nodes to connect.
|
||||||
|
case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
|
||||||
|
extends RabitWorkerResponse
|
||||||
|
|
||||||
|
// ---- Messages between ConnectionHandler actors ----
|
||||||
|
sealed trait IntraWorkerMessage
|
||||||
|
|
||||||
|
// Notify neighboring workers to decrease the counter of awaiting workers by `count`.
|
||||||
|
case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
|
||||||
|
// Request host and port information from peer ConnectionHandler actors (acting on behave of
|
||||||
|
// connecting workers.) This message will be brokered by RabitTrackerHandler.
|
||||||
|
case object RequestWorkerHostPort extends IntraWorkerMessage
|
||||||
|
// Response to the above request
|
||||||
|
case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
|
||||||
|
// A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
|
||||||
|
case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
|
||||||
|
extends IntraWorkerMessage
|
||||||
|
|
||||||
|
// ---- End of message definitions ----
|
||||||
|
|
||||||
|
def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
|
||||||
|
Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,136 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.rabit.util
|
||||||
|
|
||||||
|
import java.nio.{ByteBuffer, ByteOrder}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The assigned rank to a connecting Rabit worker, along with the information of the ranks of
|
||||||
|
* its linked peer workers, which are critical to perform Allreduce.
|
||||||
|
* When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
|
||||||
|
* client, RabitTrackerHandler utilizes LinkMap to figure out linkage relationships, and respond
|
||||||
|
* with this class as a message, which is later encoded to byte string, and sent over socket
|
||||||
|
* connection to the worker client.
|
||||||
|
*
|
||||||
|
* @param rank assigned rank (ranked by worker connection order: first worker connecting to the
|
||||||
|
* tracker is assigned rank 0, second with rank 1, etc.)
|
||||||
|
* @param neighbors ranks of neighboring workers in a tree map.
|
||||||
|
* @param ring ranks of neighboring workers in a ring map.
|
||||||
|
* @param parent rank of the parent worker.
|
||||||
|
*/
|
||||||
|
private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
|
||||||
|
ring: (Int, Int), parent: Int) {
|
||||||
|
/**
|
||||||
|
* Encode the AssignedRank message into byte sequence for socket communication with Rabit worker
|
||||||
|
* client.
|
||||||
|
* @param worldSize the number of total distributed workers. Must match `numWorkers` used in
|
||||||
|
* LinkMap.
|
||||||
|
* @return a ByteBuffer containing encoded data.
|
||||||
|
*/
|
||||||
|
def toByteBuffer(worldSize: Int): ByteBuffer = {
|
||||||
|
val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder())
|
||||||
|
buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length)
|
||||||
|
// neighbors in tree structure
|
||||||
|
neighbors.foreach { n => buffer.putInt(n) }
|
||||||
|
buffer.putInt(if (ring._1 != -1 && ring._1 != rank) ring._1 else -1)
|
||||||
|
buffer.putInt(if (ring._2 != -1 && ring._2 != rank) ring._2 else -1)
|
||||||
|
|
||||||
|
buffer.flip()
|
||||||
|
buffer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private[rabit] class LinkMap(numWorkers: Int) {
|
||||||
|
private def getNeighbors(rank: Int): Seq[Int] = {
|
||||||
|
val rank1 = rank + 1
|
||||||
|
Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
|
||||||
|
r >= 0 && r < numWorkers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct a ring structure that tends to share nodes with the tree.
|
||||||
|
*
|
||||||
|
* @param treeMap
|
||||||
|
* @param parentMap
|
||||||
|
* @param rank
|
||||||
|
* @return Seq[Int] instance starting from rank.
|
||||||
|
*/
|
||||||
|
private def constructShareRing(treeMap: Map[Int, Seq[Int]],
|
||||||
|
parentMap: Map[Int, Int],
|
||||||
|
rank: Int = 0): Seq[Int] = {
|
||||||
|
treeMap(rank).toSet - parentMap(rank) match {
|
||||||
|
case emptySet if emptySet.isEmpty =>
|
||||||
|
List(rank)
|
||||||
|
case connectionSet =>
|
||||||
|
connectionSet.zipWithIndex.foldLeft(List(rank)) {
|
||||||
|
case (ringSeq, (v, cnt)) =>
|
||||||
|
val vConnSeq = constructShareRing(treeMap, parentMap, v)
|
||||||
|
vConnSeq match {
|
||||||
|
case vconn if vconn.size == cnt + 1 =>
|
||||||
|
ringSeq ++ vconn.reverse
|
||||||
|
case vconn =>
|
||||||
|
ringSeq ++ vconn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Construct a ring connection used to recover local data.
|
||||||
|
*
|
||||||
|
* @param treeMap
|
||||||
|
* @param parentMap
|
||||||
|
*/
|
||||||
|
private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = {
|
||||||
|
assert(parentMap(0) == -1)
|
||||||
|
|
||||||
|
val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
|
||||||
|
assert(sharedRing.length == treeMap.size)
|
||||||
|
|
||||||
|
(0 until numWorkers).map { r =>
|
||||||
|
val rPrev = (r + numWorkers - 1) % numWorkers
|
||||||
|
val rNext = (r + 1) % numWorkers
|
||||||
|
sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
|
||||||
|
}.toMap
|
||||||
|
}
|
||||||
|
|
||||||
|
private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
|
||||||
|
private[this] val parentMap_ = (0 until numWorkers).map{ r => r -> ((r + 1) / 2 - 1) }.toMap
|
||||||
|
private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)
|
||||||
|
val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
|
||||||
|
case ((rmap, k), i) =>
|
||||||
|
val kNext = ringMap_(k)._2
|
||||||
|
(rmap ++ Map(kNext -> (i + 1)), kNext)
|
||||||
|
}._1
|
||||||
|
|
||||||
|
val ringMap = ringMap_.map {
|
||||||
|
case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
|
||||||
|
}
|
||||||
|
val treeMap = treeMap_.map {
|
||||||
|
case (k, vSeq) => rMap_(k) -> vSeq.map{ v => rMap_(v) }
|
||||||
|
}
|
||||||
|
val parentMap = parentMap_.map {
|
||||||
|
case (k, v) if k == 0 =>
|
||||||
|
rMap_(k) -> -1
|
||||||
|
case (k, v) =>
|
||||||
|
rMap_(k) -> rMap_(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
def assignRank(rank: Int): AssignedRank = {
|
||||||
|
AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,39 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.rabit.util
|
||||||
|
|
||||||
|
import java.nio.{ByteOrder, ByteBuffer}
|
||||||
|
import akka.util.ByteString
|
||||||
|
|
||||||
|
private[rabit] object RabitTrackerHelpers {
|
||||||
|
implicit class ByteStringHelplers(bs: ByteString) {
|
||||||
|
// Java by default uses big endian. Enforce native endian so that
|
||||||
|
// the byte order is consistent with the workers.
|
||||||
|
def asNativeOrderByteBuffer: ByteBuffer = {
|
||||||
|
bs.asByteBuffer.order(ByteOrder.nativeOrder())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit class ByteBufferHelpers(buf: ByteBuffer) {
|
||||||
|
def getString: String = {
|
||||||
|
val len = buf.getInt()
|
||||||
|
val stringBuffer = ByteBuffer.allocate(len).order(ByteOrder.nativeOrder())
|
||||||
|
buf.get(stringBuffer.array(), 0, len)
|
||||||
|
new String(stringBuffer.array(), "utf-8")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -94,6 +94,7 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
|||||||
long max_elem = cbatch.offset[cbatch.size];
|
long max_elem = cbatch.offset[cbatch.size];
|
||||||
cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0);
|
cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0);
|
||||||
cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);
|
cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);
|
||||||
|
|
||||||
CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
|
CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
|
||||||
<< "batch.index.length must equal batch.offset.back()";
|
<< "batch.index.length must equal batch.offset.back()";
|
||||||
CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem)
|
CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem)
|
||||||
@ -756,3 +757,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
|||||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: RabitAllreduce
|
||||||
|
* Signature: (Ljava/nio/ByteBuffer;III)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||||
|
(JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
|
||||||
|
void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
|
||||||
|
RabitAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op, NULL, NULL);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|||||||
@ -303,6 +303,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||||
(JNIEnv *, jclass, jintArray);
|
(JNIEnv *, jclass, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: RabitAllreduce
|
||||||
|
* Signature: (Ljava/nio/ByteBuffer;III)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||||
|
(JNIEnv *, jclass, jobject, jint, jint, jint);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -0,0 +1,224 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.rabit
|
||||||
|
|
||||||
|
import java.nio.{ByteBuffer, ByteOrder}
|
||||||
|
|
||||||
|
import akka.actor.{ActorRef, ActorSystem}
|
||||||
|
import akka.io.Tcp
|
||||||
|
import akka.testkit.{ImplicitSender, TestFSMRef, TestKit, TestProbe}
|
||||||
|
import akka.util.ByteString
|
||||||
|
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler
|
||||||
|
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler._
|
||||||
|
import ml.dmlc.xgboost4j.scala.rabit.util.LinkMap
|
||||||
|
import org.junit.runner.RunWith
|
||||||
|
import org.scalatest.junit.JUnitRunner
|
||||||
|
import org.scalatest.{FlatSpecLike, Matchers}
|
||||||
|
|
||||||
|
import scala.concurrent.Promise
|
||||||
|
|
||||||
|
object RabitTrackerConnectionHandlerTest {
|
||||||
|
def intSeqToByteString(seq: Seq[Int]): ByteString = {
|
||||||
|
val buf = ByteBuffer.allocate(seq.length * 4).order(ByteOrder.nativeOrder())
|
||||||
|
seq.foreach { i => buf.putInt(i) }
|
||||||
|
buf.flip()
|
||||||
|
ByteString.fromByteBuffer(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@RunWith(classOf[JUnitRunner])
|
||||||
|
class RabitTrackerConnectionHandlerTest
|
||||||
|
extends TestKit(ActorSystem("RabitTrackerConnectionHandlerTest"))
|
||||||
|
with FlatSpecLike with Matchers with ImplicitSender {
|
||||||
|
|
||||||
|
import RabitTrackerConnectionHandlerTest._
|
||||||
|
|
||||||
|
val magic = intSeqToByteString(List(0xff99))
|
||||||
|
|
||||||
|
"RabitTrackerConnectionHandler" should "handle Rabit client 'start' command properly" in {
|
||||||
|
val trackerProbe = TestProbe()
|
||||||
|
val connProbe = TestProbe()
|
||||||
|
|
||||||
|
val worldSize = 4
|
||||||
|
|
||||||
|
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
|
||||||
|
trackerProbe.ref, connProbe.ref))
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
|
||||||
|
|
||||||
|
// send mock magic number
|
||||||
|
fsm ! Tcp.Received(magic)
|
||||||
|
connProbe.expectMsg(Tcp.Write(magic))
|
||||||
|
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
|
||||||
|
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
|
||||||
|
// ResumeReading should be seen once state transitions
|
||||||
|
connProbe.expectMsg(Tcp.ResumeReading)
|
||||||
|
|
||||||
|
// send mock tracker command in fragments: the handler should be able to handle it.
|
||||||
|
val bufRank = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
|
||||||
|
bufRank.putInt(0).putInt(worldSize).flip()
|
||||||
|
|
||||||
|
val bufJobId = ByteBuffer.allocate(5).order(ByteOrder.nativeOrder())
|
||||||
|
bufJobId.putInt(1).put(Array[Byte]('0')).flip()
|
||||||
|
|
||||||
|
val bufCmd = ByteBuffer.allocate(9).order(ByteOrder.nativeOrder())
|
||||||
|
bufCmd.putInt(5).put("start".getBytes()).flip()
|
||||||
|
|
||||||
|
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufRank))
|
||||||
|
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufJobId))
|
||||||
|
|
||||||
|
// the state should not change for incomplete command data.
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
|
||||||
|
|
||||||
|
// send the last fragment, and expect message at tracker actor.
|
||||||
|
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
|
||||||
|
trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
|
||||||
|
|
||||||
|
val linkMap = new LinkMap(worldSize)
|
||||||
|
val assignedRank = linkMap.assignRank(0)
|
||||||
|
trackerProbe.reply(assignedRank)
|
||||||
|
|
||||||
|
connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
|
||||||
|
assignedRank.toByteBuffer(worldSize)
|
||||||
|
)))
|
||||||
|
|
||||||
|
// reading should be suspended upon transitioning to BuildingLinkMap
|
||||||
|
connProbe.expectMsg(Tcp.SuspendReading)
|
||||||
|
// state should transition with according state data changes.
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
|
||||||
|
fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
|
||||||
|
connProbe.expectMsg(Tcp.ResumeReading)
|
||||||
|
|
||||||
|
// since the connection handler in test has rank 0, it will not have any nodes to connect to.
|
||||||
|
fsm ! Tcp.Received(intSeqToByteString(List(0)))
|
||||||
|
trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
|
||||||
|
|
||||||
|
// return mock response to the connection handler
|
||||||
|
val awaitConnPromise = Promise[AwaitingConnections]()
|
||||||
|
awaitConnPromise.success(AwaitingConnections(Map.empty[Int, ActorRef],
|
||||||
|
fsm.underlyingActor.getNeighboringWorkers.size
|
||||||
|
))
|
||||||
|
fsm ! awaitConnPromise.future
|
||||||
|
connProbe.expectMsg(Tcp.Write(
|
||||||
|
intSeqToByteString(List(0, fsm.underlyingActor.getNeighboringWorkers.size))
|
||||||
|
))
|
||||||
|
connProbe.expectMsg(Tcp.SuspendReading)
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingErrorCount
|
||||||
|
connProbe.expectMsg(Tcp.ResumeReading)
|
||||||
|
|
||||||
|
// send mock error count (0)
|
||||||
|
fsm ! Tcp.Received(intSeqToByteString(List(0)))
|
||||||
|
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
|
||||||
|
connProbe.expectMsg(Tcp.ResumeReading)
|
||||||
|
|
||||||
|
// simulate Tcp.PeerClosed event first, then Tcp.Received to test handling of async events.
|
||||||
|
fsm ! Tcp.PeerClosed
|
||||||
|
// state should not transition
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
|
||||||
|
fsm ! Tcp.Received(intSeqToByteString(List(32768)))
|
||||||
|
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
|
||||||
|
connProbe.expectMsg(Tcp.ResumeReading)
|
||||||
|
|
||||||
|
trackerProbe.expectMsg(RabitWorkerHandler.WorkerStarted("localhost", 0, 2))
|
||||||
|
|
||||||
|
val handlerStopProbe = TestProbe()
|
||||||
|
handlerStopProbe watch fsm
|
||||||
|
|
||||||
|
// simulate connections from other workers by mocking ReduceWaitCount commands
|
||||||
|
fsm ! RabitWorkerHandler.ReduceWaitCount(1)
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
|
||||||
|
fsm ! RabitWorkerHandler.ReduceWaitCount(1)
|
||||||
|
trackerProbe.expectMsg(RabitWorkerHandler.DropFromWaitingList(0))
|
||||||
|
handlerStopProbe.expectTerminated(fsm)
|
||||||
|
|
||||||
|
// all done.
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "forward print command to tracker" in {
|
||||||
|
val trackerProbe = TestProbe()
|
||||||
|
val connProbe = TestProbe()
|
||||||
|
|
||||||
|
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
|
||||||
|
trackerProbe.ref, connProbe.ref))
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
|
||||||
|
|
||||||
|
fsm ! Tcp.Received(magic)
|
||||||
|
connProbe.expectMsg(Tcp.Write(magic))
|
||||||
|
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
|
||||||
|
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
|
||||||
|
// ResumeReading should be seen once state transitions
|
||||||
|
connProbe.expectMsg(Tcp.ResumeReading)
|
||||||
|
|
||||||
|
val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!")
|
||||||
|
fsm ! Tcp.Received(printCmd.encode)
|
||||||
|
|
||||||
|
trackerProbe.expectMsg(printCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "handle spill-over Tcp data correctly between state transition" in {
|
||||||
|
val trackerProbe = TestProbe()
|
||||||
|
val connProbe = TestProbe()
|
||||||
|
|
||||||
|
val worldSize = 4
|
||||||
|
|
||||||
|
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
|
||||||
|
trackerProbe.ref, connProbe.ref))
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
|
||||||
|
|
||||||
|
// send mock magic number
|
||||||
|
fsm ! Tcp.Received(magic)
|
||||||
|
connProbe.expectMsg(Tcp.Write(magic))
|
||||||
|
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
|
||||||
|
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
|
||||||
|
// ResumeReading should be seen once state transitions
|
||||||
|
connProbe.expectMsg(Tcp.ResumeReading)
|
||||||
|
|
||||||
|
// send mock tracker command in fragments: the handler should be able to handle it.
|
||||||
|
val bufCmd = ByteBuffer.allocate(26).order(ByteOrder.nativeOrder())
|
||||||
|
bufCmd.putInt(0).putInt(worldSize).putInt(1).put(Array[Byte]('0'))
|
||||||
|
.putInt(5).put("start".getBytes())
|
||||||
|
// spilled-over data
|
||||||
|
.putInt(0).flip()
|
||||||
|
|
||||||
|
// send data with 4 extra bytes corresponding to the next state.
|
||||||
|
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
|
||||||
|
|
||||||
|
trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
|
||||||
|
|
||||||
|
val linkMap = new LinkMap(worldSize)
|
||||||
|
val assignedRank = linkMap.assignRank(0)
|
||||||
|
trackerProbe.reply(assignedRank)
|
||||||
|
|
||||||
|
connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
|
||||||
|
assignedRank.toByteBuffer(worldSize)
|
||||||
|
)))
|
||||||
|
|
||||||
|
// reading should be suspended upon transitioning to BuildingLinkMap
|
||||||
|
connProbe.expectMsg(Tcp.SuspendReading)
|
||||||
|
// state should transition with according state data changes.
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
|
||||||
|
fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
|
||||||
|
connProbe.expectMsg(Tcp.ResumeReading)
|
||||||
|
|
||||||
|
// the handler should be able to handle spill-over data, and stash it until state transition.
|
||||||
|
trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user