diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala
index 366bf7b3d..58afd82e1 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala
@@ -57,7 +57,7 @@ object CustomObjective {
case e: XGBoostError =>
logger.error(e)
null
- case _ =>
+ case _: Throwable =>
null
}
val grad = new Array[Float](nrow)
diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala
index 4c6adee99..9ac8c2668 100644
--- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala
+++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala
@@ -85,7 +85,7 @@ object XGBoost {
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
XGBoostModel = {
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
- if (tracker.start()) {
+ if (tracker.start(0L)) {
dtrain
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
.reduce((x, y) => x).collect().head
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index f4a05cc1d..bf22f7fcc 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -16,11 +16,10 @@
package ml.dmlc.xgboost4j.scala.spark
-import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
-
-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
+import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
+import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FSDataInputStream, Path}
@@ -30,6 +29,25 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.{SparkContext, TaskContext}
+import scala.concurrent.duration.{Duration, MILLISECONDS}
+
+object TrackerConf {
+ def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python")
+}
+
+/**
+ * Rabit tracker configurations.
+ * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
+ * Set timeout length to zero to disable timeout.
+ * Use a finite, non-zero timeout value to prevent tracker from
+ * hanging indefinitely (supported by "scala" implementation only.)
+ * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
+ * the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
+ * in Scala without Python components, and with full support of timeouts.
+ * The Scala implementation is currently experimental, use at your own risk.
+ */
+case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
+
object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
@@ -80,7 +98,7 @@ object XGBoost extends Serializable {
private[spark] def buildDistributedBoosters(
trainingSet: RDD[MLLabeledPoint],
xgBoostConfMap: Map[String, Any],
- rabitEnv: mutable.Map[String, String],
+ rabitEnv: java.util.Map[String, String],
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
import DataUtils._
@@ -92,7 +110,7 @@ object XGBoost extends Serializable {
partitionedTrainingSet.mapPartitions {
trainingSamples =>
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
- Rabit.init(rabitEnv.asJava)
+ Rabit.init(rabitEnv)
var booster: Booster = null
if (trainingSamples.hasNext) {
val cacheFileName: String = {
@@ -211,9 +229,21 @@ object XGBoost extends Serializable {
overridedParams
}
- private def startTracker(nWorkers: Int): RabitTracker = {
- val tracker = new RabitTracker(nWorkers)
- require(tracker.start(), "FAULT: Failed to start tracker")
+ private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
+ val tracker: IRabitTracker = trackerConf.trackerImpl match {
+ case "scala" => new RabitTracker(nWorkers)
+ case "python" => new PyRabitTracker(nWorkers)
+ case _ => new PyRabitTracker(nWorkers)
+ }
+
+ val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) {
+ trackerConf.workerConnectionTimeout.toMillis
+ } else {
+ // 0 == Duration.Inf
+ 0L
+ }
+
+ require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker")
tracker
}
@@ -227,7 +257,7 @@ object XGBoost extends Serializable {
* @param obj the user-defined objective function, null by default
* @param eval the user-defined evaluation function, null by default
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
- * true, the user may save the RAM cost for running XGBoost within Spark
+ * true, the user may save the RAM cost for running XGBoost within Spark
* @param missing the value represented the missing value in the dataset
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
* @return XGBoostModel when successful training
@@ -243,19 +273,26 @@ object XGBoost extends Serializable {
" you have to specify the objective type as classification or regression with a" +
" customized objective function")
}
- val tracker = startTracker(nWorkers)
+ val trackerConf = params.get("tracker_conf") match {
+ case None => TrackerConf()
+ case Some(conf: TrackerConf) => conf
+ case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
+ "instance of TrackerConf.")
+ }
+ val tracker = startTracker(nWorkers, trackerConf)
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
- tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing)
+ tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
val sparkJobThread = new Thread() {
override def run() {
// force the job
boosters.foreachPartition(() => _)
}
}
+ sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
- val trackerReturnVal = tracker.waitFor()
+ val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal")
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
isClsTask)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
index 8d0f60cfd..212daadbc 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
@@ -16,9 +16,12 @@
package ml.dmlc.xgboost4j.scala.spark.params
+import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.spark.ml.param._
+import scala.concurrent.duration.{Duration, NANOSECONDS}
+
trait GeneralParams extends Params {
/**
@@ -69,7 +72,38 @@ trait GeneralParams extends Params {
*/
val missing = new FloatParam(this, "missing", "the value treated as missing")
+ /**
+ * Rabit tracker configurations. The parameter must be provided as an instance of the
+ * TrackerConf class, which has the following definition:
+ *
+ * case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
+ *
+ * See below for detailed explanations.
+ *
+ * - trackerImpl: Select the implementation of Rabit tracker.
+ * default: "python"
+ *
+ * Choice between "python" or "scala". The former utilizes the Java wrapper of the
+ * Python Rabit tracker (in dmlc_core), and does not support timeout settings.
+ * The "scala" version removes Python components, and fully supports timeout settings.
+ *
+ * - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
+ * default: 0 millisecond (no timeout)
+ *
+ * The timeout value should take the time of data loading and pre-processing into account,
+ * due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
+ * perform data transformation before calling XGBoost.train(), so that this timeout truly
+ * reflects the connection delay. Set a reasonable timeout value to prevent model
+ * training/testing from hanging indefinitely, possibly due to network issues.
+ * Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
+ * Ignored if the tracker implementation is "python".
+ */
+ val trackerConf = new Param[TrackerConf](this, "tracker_conf", "Rabit tracker configurations")
+
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
useExternalMemory -> false, silent -> 0,
- customObj -> null, customEval -> null, missing -> Float.NaN)
+ customObj -> null, customEval -> null, missing -> Float.NaN,
+ trackerConf -> TrackerConf()
+ )
}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitTrackerRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitTrackerRobustnessSuite.scala
new file mode 100644
index 000000000..2d1dc2711
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitTrackerRobustnessSuite.scala
@@ -0,0 +1,169 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
+import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
+import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.FunSuite
+
+
+class RabitTrackerRobustnessSuite extends FunSuite with Utils {
+ test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
+ /*
+ Deliberately create new instances of SparkContext in each unit test to avoid reusing the
+ same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
+ by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
+ corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
+ tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
+ that should be avoided.
+ */
+ val sparkConf = new SparkConf().setMaster("local[*]")
+ .setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
+ implicit val sparkContext = new SparkContext(sparkConf)
+ sparkContext.setLogLevel("ERROR")
+
+ val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
+
+ val tracker = new PyRabitTracker(numWorkers)
+ tracker.start(0)
+ val trackerEnvs = tracker.getWorkerEnvs
+
+ val workerCount: Int = numWorkers
+ /*
+ Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
+ last created worker. A cascading event chain will be triggered once the RuntimeException is
+ thrown: the thread running the dummy spark job (sparkThread) catches the exception and
+ delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
+
+ The Java RabitTracker class reacts to exceptions by killing the spawned process running
+ the Python tracker. If at least one Rabit worker has not yet connected to the tracker before
+ it is killed, the resulting connection failure will trigger the Rabit worker to call
+ "exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
+
+ In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
+ isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
+ running in separate containers. However, as unit tests are run in Spark local mode, in which
+ tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
+ ultimately kills the entire process, which also happens to host the Spark driver, causing
+ the entire Spark application to crash.
+
+ To prevent unit tests from crashing, deterministic delays were introduced to make sure that
+ the exception is thrown at last, ideally after all worker connections have been established.
+ For the same reason, the Java RabitTracker class delays the killing of the Python tracker
+ process to ensure that pending worker connections are handled.
+ */
+ val dummyTasks = rdd.mapPartitions { iter =>
+ Rabit.init(trackerEnvs)
+ val index = iter.next()
+ Thread.sleep(100 + index * 10)
+ if (index == workerCount) {
+ // kill the worker by throwing an exception
+ throw new RuntimeException("Worker exception.")
+ }
+ Rabit.shutdown()
+ Iterator(index)
+ }.cache()
+
+ val sparkThread = new Thread() {
+ override def run(): Unit = {
+ // forces a Spark job.
+ dummyTasks.foreachPartition(() => _)
+ }
+ }
+
+ sparkThread.setUncaughtExceptionHandler(tracker)
+ sparkThread.start()
+ assert(tracker.waitFor(0) != 0)
+ sparkContext.stop()
+ }
+
+ test("test Scala RabitTracker's exception handling: it should not hang forever.") {
+ val sparkConf = new SparkConf().setMaster("local[*]")
+ .setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
+ implicit val sparkContext = new SparkContext(sparkConf)
+ sparkContext.setLogLevel("ERROR")
+
+ val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
+
+ val tracker = new ScalaRabitTracker(numWorkers)
+ tracker.start(0)
+ val trackerEnvs = tracker.getWorkerEnvs
+
+ val workerCount: Int = numWorkers
+ val dummyTasks = rdd.mapPartitions { iter =>
+ Rabit.init(trackerEnvs)
+ val index = iter.next()
+ Thread.sleep(100 + index * 10)
+ if (index == workerCount) {
+ // kill the worker by throwing an exception
+ throw new RuntimeException("Worker exception.")
+ }
+ Rabit.shutdown()
+ Iterator(index)
+ }.cache()
+
+ val sparkThread = new Thread() {
+ override def run(): Unit = {
+ // forces a Spark job.
+ dummyTasks.foreachPartition(() => _)
+ }
+ }
+ sparkThread.setUncaughtExceptionHandler(tracker)
+ sparkThread.start()
+ assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
+ sparkContext.stop()
+ }
+
+ test("test Scala RabitTracker's workerConnectionTimeout") {
+ val sparkConf = new SparkConf().setMaster("local[*]")
+ .setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
+ implicit val sparkContext = new SparkContext(sparkConf)
+ sparkContext.setLogLevel("ERROR")
+
+ val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
+
+ val tracker = new ScalaRabitTracker(numWorkers)
+ tracker.start(500)
+ val trackerEnvs = tracker.getWorkerEnvs
+
+ val dummyTasks = rdd.mapPartitions { iter =>
+ val index = iter.next()
+ // simulate that the first worker cannot connect to tracker due to network issues.
+ if (index != 1) {
+ Rabit.init(trackerEnvs)
+ Thread.sleep(1000)
+ Rabit.shutdown()
+ }
+
+ Iterator(index)
+ }.cache()
+
+ val sparkThread = new Thread() {
+ override def run(): Unit = {
+ // forces a Spark job.
+ dummyTasks.foreachPartition(() => _)
+ }
+ }
+ sparkThread.setUncaughtExceptionHandler(tracker)
+ sparkThread.start()
+ // should fail due to connection timeout
+ assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
+ sparkContext.stop()
+ }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 5faed7234..1874a3b6d 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -17,18 +17,60 @@
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
+import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
import scala.collection.mutable.ListBuffer
import scala.util.Random
-
-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
+import scala.concurrent.duration._
+import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
+import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.{Vector => SparkVector, Vectors}
+import org.apache.spark.ml.linalg.{Vectors, Vector => SparkVector}
import org.apache.spark.rdd.RDD
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
+ test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
+ val vectorLength = 100
+ val rdd = sc.parallelize(
+ (1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
+
+ val tracker = new RabitTracker(numWorkers)
+ tracker.start(0)
+ val trackerEnvs = tracker.getWorkerEnvs
+ val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
+
+ val rawData = rdd.mapPartitions { iter =>
+ Iterator(iter.toArray)
+ }.collect()
+
+ val maxVec = (0 until vectorLength).toArray.map { j =>
+ (0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
+ }
+
+ val allReduceResults = rdd.mapPartitions { iter =>
+ Rabit.init(trackerEnvs)
+ val arr = iter.toArray
+ val results = Rabit.allReduce(arr, Rabit.OpType.MAX)
+ Rabit.shutdown()
+ Iterator(results)
+ }.cache()
+
+ val sparkThread = new Thread() {
+ override def run(): Unit = {
+ allReduceResults.foreachPartition(() => _)
+ val byPartitionResults = allReduceResults.collect()
+ assert(byPartitionResults(0).length == vectorLength)
+ collectedAllReduceResults.put(byPartitionResults(0))
+ }
+ }
+ sparkThread.start()
+ assert(tracker.waitFor(0L) == 0)
+ sparkThread.join()
+
+ assert(collectedAllReduceResults.poll().sameElements(maxVec))
+ }
test("build RDD containing boosters with the specified worker number") {
val trainingRDD = buildTrainingRDD(sc)
@@ -36,7 +78,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
trainingRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap,
- new scala.collection.mutable.HashMap[String, String],
+ new java.util.HashMap[String, String](),
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true)
val boosterCount = boosterRDD.count()
assert(boosterCount === 2)
@@ -59,6 +101,21 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
cleanExternalCache("XGBoostSuite")
}
+ test("training with Scala-implemented Rabit tracker") {
+ val eval = new EvalError()
+ val trainingRDD = buildTrainingRDD(sc)
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ import DataUtils._
+ val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+ val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+ "objective" -> "binary:logistic",
+ "tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
+ nWorkers = numWorkers, useExternalMemory = true)
+ assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
+ testSetDMatrix) < 0.1)
+ }
+
test("test with dense vectors containing missing value") {
def buildDenseRDD(): RDD[LabeledPoint] = {
val nrow = 100
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
index 0efddd5dd..00cfed904 100644
--- a/jvm-packages/xgboost4j/pom.xml
+++ b/jvm-packages/xgboost4j/pom.xml
@@ -108,5 +108,17 @@
4.11
test
+
+ com.typesafe.akka
+ akka-actor_${scala.binary.version}
+ 2.3.11
+ compile
+
+
+ com.typesafe.akka
+ akka-testkit_${scala.binary.version}
+ 2.3.11
+ test
+
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java
new file mode 100644
index 000000000..2a2fcd423
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java
@@ -0,0 +1,43 @@
+package ml.dmlc.xgboost4j.java;
+
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Interface for Rabit tracker implementations with three public methods:
+ *
+ * - start(timeout): Start the Rabit tracker awaiting worker connections, with a given
+ * timeout value (in milliseconds.)
+ * - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
+ * - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
+ * milliseconds.
+ *
+ * Each implementation is expected to implement a callback function
+ *
+ * public void uncaughtException(Thread t, Throwable e) { ... }
+ *
+ * to interrupt waitFor() in order to prevent the tracker from hanging indefinitely.
+ *
+ * The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
+ * brokers connections between workers.
+ */
+public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
+ public enum TrackerStatus {
+ SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
+
+ private int statusCode;
+
+ TrackerStatus(int statusCode) {
+ this.statusCode = statusCode;
+ }
+
+ public int getStatusCode() {
+ return this.statusCode;
+ }
+ }
+
+ Map getWorkerEnvs();
+ boolean start(long workerConnectionTimeout);
+ // taskExecutionTimeout has no effect in current version of XGBoost.
+ int waitFor(long taskExecutionTimeout);
+}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java
index 3429dc3dd..6e996494b 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java
@@ -2,6 +2,9 @@ package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.Serializable;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.Arrays;
import java.util.Map;
import org.apache.commons.logging.Log;
@@ -22,6 +25,42 @@ public class Rabit {
}
}
+ public enum OpType implements Serializable {
+ MAX(0), MIN(1), SUM(2), BITWISE_OR(3);
+
+ private int op;
+
+ public int getOperand() {
+ return this.op;
+ }
+
+ OpType(int op) {
+ this.op = op;
+ }
+ }
+
+ public enum DataType implements Serializable {
+ CHAR(0, 1), UCHAR(1, 1), INT(2, 4), UNIT(3, 4),
+ LONG(4, 8), ULONG(5, 8), FLOAT(6, 4), DOUBLE(7, 8),
+ LONGLONG(8, 8), ULONGLONG(9, 8);
+
+ private int enumOp;
+ private int size;
+
+ public int getEnumOp() {
+ return this.enumOp;
+ }
+
+ public int getSize() {
+ return this.size;
+ }
+
+ DataType(int enumOp, int size) {
+ this.enumOp = enumOp;
+ this.size = size;
+ }
+ }
+
private static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
@@ -92,4 +131,30 @@ public class Rabit {
checkCall(XGBoostJNI.RabitGetWorldSize(out));
return out[0];
}
+
+ /**
+ * perform Allreduce on distributed float vectors using operator op.
+ * This implementation of allReduce does not support customized prepare function callback in the
+ * native code, as this function is meant for testing purposes only (to test the Rabit tracker.)
+ *
+ * @param elements local elements on distributed workers.
+ * @param op operator used for Allreduce.
+ * @return All-reduced float elements according to the given operator.
+ */
+ public static float[] allReduce(float[] elements, OpType op) {
+ DataType dataType = DataType.FLOAT;
+ ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
+ .order(ByteOrder.nativeOrder());
+
+ for (float el : elements) {
+ buffer.putFloat(el);
+ }
+ buffer.flip();
+
+ XGBoostJNI.RabitAllreduce(buffer, elements.length, dataType.getEnumOp(), op.getOperand());
+ float[] results = new float[elements.length];
+ buffer.asFloatBuffer().get(results);
+
+ return results;
+ }
}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
index bc419c564..d2008cd7f 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
@@ -5,15 +5,24 @@ package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
- * Distributed RabitTracker, need to be started on driver code before running distributed jobs.
+ * Java implementation of the Rabit tracker to coordinate distributed workers.
+ * As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
+ * start() and waitFor() methods (i.e., the timeout is infinite.)
+ *
+ * For systems lacking Python environment, or for timeout functionality, consider using the Scala
+ * Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
+ * provides timeout support.
+ *
+ * The tracker must be started on driver node before running distributed jobs.
*/
-public class RabitTracker {
+public class RabitTracker implements IRabitTracker {
// Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class);
// tracker python file.
@@ -69,7 +78,6 @@ public class RabitTracker {
}
}
-
public RabitTracker(int numWorkers)
throws XGBoostError {
if (numWorkers < 1) {
@@ -78,6 +86,17 @@ public class RabitTracker {
this.numWorkers = numWorkers;
}
+ public void uncaughtException(Thread t, Throwable e) {
+ logger.error("Uncaught exception thrown by worker:", e);
+ try {
+ Thread.sleep(5000L);
+ } catch (InterruptedException ex) {
+ logger.error(ex);
+ } finally {
+ trackerProcess.get().destroy();
+ }
+ }
+
/**
* Get environments that can be used to pass to worker.
* @return The environment settings.
@@ -126,7 +145,13 @@ public class RabitTracker {
}
}
- public boolean start() {
+ public boolean start(long timeout) {
+ if (timeout > 0L) {
+ logger.warn("Python RabitTracker does not support timeout. " +
+ "The tracker will wait for all workers to connect indefinitely, unless " +
+ "it is interrupted manually. Use the Scala RabitTracker for timeout support.");
+ }
+
if (startTrackerProcess()) {
logger.debug("Tracker started, with env=" + envs.toString());
System.out.println("Tracker started, with env=" + envs.toString());
@@ -142,7 +167,14 @@ public class RabitTracker {
}
}
- public int waitFor() {
+ public int waitFor(long timeout) {
+ if (timeout > 0L) {
+ logger.warn("Python RabitTracker does not support timeout. " +
+ "The tracker will wait for either all workers to finish tasks and send " +
+ "shutdown signal, or manual interruptions. " +
+ "Use the Scala RabitTracker for timeout support.");
+ }
+
try {
trackerProcess.get().waitFor();
int returnVal = trackerProcess.get().exitValue();
@@ -153,7 +185,7 @@ public class RabitTracker {
// we should not get here as RabitTracker is accessed in the main thread
e.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
- return 1;
+ return TrackerStatus.INTERRUPTED.getStatusCode();
}
}
}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
index 4ecef65a7..630c61647 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
@@ -16,6 +16,8 @@
package ml.dmlc.xgboost4j.java;
+import java.nio.ByteBuffer;
+
/**
* xgboost JNI functions
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
@@ -97,4 +99,9 @@ class XGBoostJNI {
public final static native int RabitGetRank(int[] out);
public final static native int RabitGetWorldSize(int[] out);
public final static native int RabitVersionNumber(int[] out);
+
+ // Perform Allreduce operation on data in sendrecvbuf.
+ // This JNI function does not support the callback function for data preparation yet.
+ final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
+ int enum_dtype, int enum_op);
}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala
new file mode 100644
index 000000000..d6ca42e75
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala
@@ -0,0 +1,190 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.rabit
+
+import java.net.{InetAddress, InetSocketAddress}
+
+import akka.actor.ActorSystem
+import akka.pattern.ask
+import ml.dmlc.xgboost4j.java.IRabitTracker
+import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler
+
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Future}
+import scala.util.{Failure, Success, Try}
+
+/**
+ * Scala implementation of the Rabit tracker interface without Python dependency.
+ * The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker
+ * (and thus any distributed tasks) from hanging indefinitely due to network issues or worker node
+ * failures.
+ *
+ * Note that this implementation is currently experimental, and should be used at your own risk.
+ *
+ * Example usage:
+ * {{{
+ * import scala.concurrent.duration._
+ *
+ * val tracker = new RabitTracker(32)
+ * // allow up to 10 minutes for all workers to connect to the tracker.
+ * tracker.start(10 minutes)
+ *
+ * /* ...
+ * launching workers in parallel
+ * ...
+ * */
+ *
+ * // wait for worker execution up to 6 hours.
+ * // providing a finite timeout prevents a long-running task from hanging forever in
+ * // catastrophic events, like the loss of an executor during model training.
+ * tracker.waitFor(6 hours)
+ * }}}
+ *
+ * @param numWorkers Number of distributed workers from which the tracker expects connections.
+ * @param port The minimum port number that the tracker binds to.
+ * If port is omitted, or given as None, a random ephemeral port is chosen at runtime.
+ * @param maxPortTrials The maximum number of trials of socket binding, by sequentially
+ * increasing the port number.
+ */
+private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
+ maxPortTrials: Int = 1000)
+ extends IRabitTracker {
+
+ import scala.collection.JavaConverters._
+
+ require(numWorkers >=1, "numWorkers must be greater than or equal to one (1).")
+
+ val system = ActorSystem.create("RabitTracker")
+ val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
+ implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds)
+ private[this] val tcpBindingTimeout: Duration = 1 minute
+
+ var workerEnvs: Map[String, String] = Map.empty
+
+ override def uncaughtException(t: Thread, e: Throwable): Unit = {
+ handler ? RabitTrackerHandler.InterruptTracker(e)
+ }
+
+ /**
+ * Start the Rabit tracker.
+ *
+ * @param timeout The timeout for awaiting connections from worker nodes.
+ * Note that when used in Spark applications, because all Spark transformations are
+ * lazily executed, the I/O time for loading RDDs/DataFrames from external sources
+ * (local dist, HDFS, S3 etc.) must be taken into account for the timeout value.
+ * If the timeout value is too small, the Rabit tracker will likely time out before workers
+ * establish connections to the tracker, due to the overhead of loading data.
+ * Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
+ * running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
+ * @return Boolean flag indicating if the Rabit tracker starts successfully.
+ */
+ private def start(timeout: Duration): Boolean = {
+ handler ? RabitTrackerHandler.StartTracker(
+ new InetSocketAddress(InetAddress.getLocalHost, port.getOrElse(0)), maxPortTrials, timeout)
+
+ // block by waiting for the actor to bind to a port
+ Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
+ .asInstanceOf[Future[Map[String, String]]]) match {
+ case Success(futurePortBound) =>
+ // The success of the Future is contingent on binding to an InetSocketAddress.
+ val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess
+ if (isBound) {
+ workerEnvs = Await.result(futurePortBound, 0 nano)
+ }
+ isBound
+ case Failure(ex: Throwable) =>
+ false
+ }
+ }
+
+ /**
+ * Start the Rabit tracker.
+ *
+ * @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
+ * connections. If a non-positive value is provided, the tracker
+ * waits for incoming worker connections indefinitely.
+ * @return Boolean flag indicating if the Rabit tracker starts successfully.
+ */
+ def start(connectionTimeoutMillis: Long): Boolean = {
+ if (connectionTimeoutMillis <= 0) {
+ start(Duration.Inf)
+ } else {
+ start(Duration.fromNanos(connectionTimeoutMillis * 1e6))
+ }
+ }
+
+ /**
+ * Get a Map of necessary environment variables to initiate Rabit workers.
+ *
+ * @return HashMap containing tracker information.
+ */
+ def getWorkerEnvs: java.util.Map[String, String] = {
+ new java.util.HashMap((workerEnvs ++ Map(
+ "DMLC_NUM_WORKER" -> numWorkers.toString,
+ "DMLC_NUM_SERVER" -> "0"
+ )).asJava)
+ }
+
+ /**
+ * Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
+ * This method blocks until timeout or task completion.
+ *
+ * @param atMost the maximum execution time for the workers. By default,
+ * the tracker waits for the workers indefinitely.
+ * @return 0 if the tasks complete successfully, and non-zero otherwise.
+ */
+ private def waitFor(atMost: Duration): Int = {
+ // request the completion Future from the tracker actor
+ Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
+ .asInstanceOf[Future[Int]]) match {
+ case Success(futureCompleted) =>
+ // wait for all workers to complete synchronously.
+ val statusCode = Try(Await.result(futureCompleted, atMost)) match {
+ case Success(n) if n == numWorkers =>
+ IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
+ case Success(n) if n < numWorkers =>
+ IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
+ case Failure(e) =>
+ IRabitTracker.TrackerStatus.FAILURE.getStatusCode
+ }
+ system.shutdown()
+ statusCode
+ case Failure(ex: Throwable) =>
+ if (!system.isTerminated) {
+ system.shutdown()
+ }
+ IRabitTracker.TrackerStatus.FAILURE.getStatusCode
+ }
+ }
+
+ /**
+ * Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
+ * This method blocks until timeout or task completion.
+ *
+ * @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a
+ * non-positive number is given, the tracker waits indefinitely.
+ * @return 0 if the tasks complete successfully, and non-zero otherwise
+ */
+ def waitFor(atMostMillis: Long): Int = {
+ if (atMostMillis <= 0) {
+ waitFor(Duration.Inf)
+ } else {
+ waitFor(Duration.fromNanos(atMostMillis * 1e6))
+ }
+ }
+}
+
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala
new file mode 100644
index 000000000..8b1c25a34
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala
@@ -0,0 +1,362 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.rabit.handler
+
+import java.net.InetSocketAddress
+import java.util.UUID
+
+import scala.concurrent.duration._
+import scala.collection.mutable
+import scala.concurrent.{Promise, TimeoutException}
+import akka.io.{IO, Tcp}
+import akka.actor._
+import ml.dmlc.xgboost4j.java.XGBoostError
+import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}
+
+import scala.util.{Failure, Random, Success, Try}
+
+/** The Akka actor for handling and coordinating Rabit worker connections.
+ * This is the main actor for handling socket connections, interacting with the synchronous
+ * tracker interface, and resolving tree/ring/parent dependencies between workers.
+ *
+ * @param numWorkers Number of workers to track.
+ */
+private[scala] class RabitTrackerHandler(numWorkers: Int)
+ extends Actor with ActorLogging {
+
+ import context.system
+ import RabitWorkerHandler._
+ import RabitTrackerHandler._
+
+ private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
+ private[this] val promisedShutdownWorkers = Promise[Int]()
+ private[this] val tcpManager = IO(Tcp)
+
+ // resolves worker connection dependency.
+ val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")
+
+ // workers that have sent "shutdown" signal
+ private[this] val shutdownWorkers = mutable.Set.empty[Int]
+ private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
+ private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
+ private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
+ private[this] var maxPortTrials = 0
+ private[this] var workerConnectionTimeout: Duration = Duration.Inf
+ private[this] var portTrials = 0
+ private[this] val startedWorkers = mutable.Set.empty[Int]
+
+ val linkMap = new LinkMap(numWorkers)
+
+ def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
+ rank match {
+ case r if r >= 0 => Some(r)
+ case _ =>
+ jobId match {
+ case "NULL" => None
+ case jid => jobToRankMap.get(jid)
+ }
+ }
+ }
+
+ /**
+ * Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
+ * by the RabitWorkerHandler.
+ *
+ * @param event Generic Tcp.Event
+ */
+ private def handleTcpEvents(event: Tcp.Event): Unit = event match {
+ case Tcp.Bound(local) =>
+ // expect all workers to connect within timeout
+ log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
+ log.info(s"Worker connection timeout is $workerConnectionTimeout.")
+
+ context.setReceiveTimeout(workerConnectionTimeout)
+ promisedWorkerEnvs.success(Map(
+ "DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
+ "DMLC_TRACKER_PORT" -> local.getPort.toString,
+ // not required because the world size will be communicated to the
+ // worker node after the rank is assigned.
+ "rabit_world_size" -> numWorkers.toString
+ ))
+
+ case Tcp.CommandFailed(cmd: Tcp.Bind) =>
+ if (portTrials < maxPortTrials) {
+ portTrials += 1
+ tcpManager ! Tcp.Bind(self,
+ new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1),
+ backlog = 256)
+ }
+
+ case Tcp.Connected(remote, local) =>
+ log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
+ // revoke timeout if all workers have connected.
+ val workerHandler = context.actorOf(RabitWorkerHandler.props(
+ remote.getAddress.getHostAddress, numWorkers, self, sender()
+ ), s"ConnectionHandler-${UUID.randomUUID().toString}")
+ val connection = sender()
+ connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)
+
+ actorRefToHost.put(workerHandler, remote.getAddress.getHostName)
+ }
+
+ /**
+ * Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
+ * to interact with the tracker interface.
+ *
+ * @param trackerMsg control messages sent by RabitTracker class.
+ */
+ private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
+ trackerMsg match {
+
+ case msg: StartTracker =>
+ maxPortTrials = msg.maxPortTrials
+ workerConnectionTimeout = msg.connectionTimeout
+
+ // if the port number is missing, try binding to a random ephemeral port.
+ if (msg.addr.getPort == 0) {
+ tcpManager ! Tcp.Bind(self,
+ new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
+ backlog = 256)
+ } else {
+ tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
+ }
+ sender() ! true
+
+ case RequestBoundFuture =>
+ sender() ! promisedWorkerEnvs.future
+
+ case RequestCompletionFuture =>
+ sender() ! promisedShutdownWorkers.future
+
+ case InterruptTracker(e) =>
+ log.error(e, "Uncaught exception thrown by worker.")
+ // make sure that waitFor() does not hang indefinitely.
+ promisedShutdownWorkers.failure(e)
+ context.stop(self)
+ }
+
+ /**
+ * Handles messages sent by child actors representing connecting Rabit workers, by brokering
+ * messages to the dependency resolver, and processing worker commands.
+ *
+ * @param workerMsg Message sent by RabitWorkerHandler actors.
+ */
+ private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
+ case req @ RequestAwaitConnWorkers(_, _) =>
+ // since the requester may request to connect to other workers
+ // that have not fully set up, delegate this request to the
+ // dependency resolver which handles the dependencies properly.
+ resolver forward req
+
+ // ---- Rabit worker commands: start/recover/shutdown/print ----
+ case WorkerTrackerPrint(_, _, _, msg) =>
+ log.info(msg.trim)
+
+ case WorkerShutdown(rank, _, _) =>
+ assert(rank >= 0, "Invalid rank.")
+ assert(!shutdownWorkers.contains(rank))
+ shutdownWorkers.add(rank)
+
+ log.info(s"Received shutdown signal from $rank")
+
+ if (shutdownWorkers.size == numWorkers) {
+ promisedShutdownWorkers.success(shutdownWorkers.size)
+ context.stop(self)
+ }
+
+ case WorkerRecover(prevRank, worldSize, jobId) =>
+ assert(prevRank >= 0)
+ sender() ! linkMap.assignRank(prevRank)
+
+ case WorkerStart(rank, worldSize, jobId) =>
+ assert(worldSize == numWorkers || worldSize == -1,
+ s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
+ )
+
+ Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
+ case Success(wkRank) =>
+ if (jobId != "NULL") {
+ jobToRankMap.put(jobId, wkRank)
+ }
+
+ val assignedRank = linkMap.assignRank(wkRank)
+ sender() ! assignedRank
+ resolver ! assignedRank
+
+ log.info("Received start signal from " +
+ s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]")
+
+ case Failure(ex: IndexOutOfBoundsException) =>
+ // More than worldSize workers have connected, likely due to executor loss.
+ // Since Rabit currently does not support crash recovery (because the Allreduce results
+ // are not cached by the tracker, and because existing workers cannot reestablish
+ // connections to newly spawned executor/worker), the most reasonable action here is to
+ // interrupt the tracker immediately with a failure state.
+ log.error("Received invalid start signal from " +
+ s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started."
+ )
+ promisedShutdownWorkers.failure(new XGBoostError("Invalid start signal" +
+ " received from worker, likely due to executor loss."))
+
+ case Failure(ex) =>
+ log.error(ex, "Unexpected error")
+ promisedShutdownWorkers.failure(ex)
+ }
+
+
+ // ---- Dependency resolving related messages ----
+ case msg @ WorkerStarted(host, rank, awaitingAcceptance) =>
+ log.info(s"Worker $host (rank: $rank) has started.")
+ resolver forward msg
+
+ startedWorkers.add(rank)
+ if (startedWorkers.size == numWorkers) {
+ log.info("All workers have started.")
+ }
+
+ case req @ DropFromWaitingList(_) =>
+ // all peer workers in dependency link map have connected;
+ // forward message to resolver to update dependencies.
+ resolver forward req
+
+ case _ =>
+ }
+
+ def receive: Actor.Receive = {
+ case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
+ case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
+ case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)
+
+ case akka.actor.ReceiveTimeout =>
+ if (startedWorkers.size < numWorkers) {
+ promisedShutdownWorkers.failure(
+ new TimeoutException("Timed out waiting for workers to connect: " +
+ s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
+ )
+ context.stop(self)
+ }
+
+ context.setReceiveTimeout(Duration.Undefined)
+ }
+}
+
+/**
+ * Resolve the dependency between nodes as they connect to the tracker.
+ * The dependency is enforced such that a worker of rank K depends on its neighbors (from the
+ * treeMap and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order of
+ * connections by workers, this dependency constraint assumes that a worker node that connects
+ * first is likely to finish setup first.
+ */
+private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
+ import RabitWorkerHandler._
+
+ context.watch(handler)
+
+ case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])
+
+ // worker nodes that have connected, but have not sent the WorkerStarted message.
+ private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
+ private val startedWorkers = mutable.Set.empty[Int]
+ // worker nodes that have started, and await for connections.
+ private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
+ private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]
+
+ def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
+ val connSet = awaitConnWorkers.toMap
+ .filterKeys(k => linkSet.contains(k))
+ AwaitingConnections(connSet, linkSet.size - connSet.size)
+ }
+
+ def receive: Actor.Receive = {
+ // a copy of the AssignedRank message that is also sent to the worker
+ case AssignedRank(rank, tree_neighbors, ring, parent) =>
+ // the workers that the worker of given `rank` depends on:
+ // worker of rank K only depends on workers with rank smaller than K.
+ val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
+ .filter{ r => r != -1 && r < rank}
+
+ log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
+ dependencyMap.put(rank, dependentWorkers)
+
+ case RequestAwaitConnWorkers(rank, toConnectSet) =>
+ val promise = Promise[AwaitingConnections]()
+
+ assert(dependencyMap.contains(rank))
+
+ val updatedDependency = dependencyMap(rank) diff startedWorkers
+ if (updatedDependency.isEmpty) {
+ // all dependencies are satisfied
+ log.debug(s"Rank $rank has all dependencies satisfied.")
+ promise.success(awaitingWorkers(toConnectSet))
+ } else {
+ log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
+ // promise is pending fulfillment due to unresolved dependency
+ pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
+ }
+
+ sender() ! promise.future
+
+ case WorkerStarted(_, started, awaitingAcceptance) =>
+ startedWorkers.add(started)
+ if (awaitingAcceptance > 0) {
+ awaitConnWorkers.put(started, sender())
+ }
+
+ // remove the started rank from all dependencies.
+ dependencyMap.remove(started)
+ dependencyMap.foreach { case (r, dset) =>
+ val updatedDependency = dset diff startedWorkers
+ // fulfill the future if all dependencies are met (started.)
+ if (updatedDependency.isEmpty) {
+ log.debug(s"Rank $r has all dependencies satisfied.")
+ pendingFulfillment.remove(r).map{
+ case Fulfillment(toConnectSet, promise) =>
+ promise.success(awaitingWorkers(toConnectSet))
+ }
+ }
+
+ dependencyMap.update(r, updatedDependency)
+ }
+
+ case DropFromWaitingList(rank) =>
+ assert(awaitConnWorkers.remove(rank).isDefined)
+
+ case Terminated(ref) =>
+ if (ref.equals(handler)) {
+ context.stop(self)
+ }
+ }
+}
+
+private[scala] object RabitTrackerHandler {
+ // Messages sent by RabitTracker to this RabitTrackerHandler actor
+ trait TrackerControlMessage
+ case object RequestCompletionFuture extends TrackerControlMessage
+ case object RequestBoundFuture extends TrackerControlMessage
+ // Start the Rabit tracker at given socket address awaiting worker connections.
+ // All workers must connect to the tracker before connectionTimeout, otherwise the tracker will
+ // shut down due to timeout.
+ case class StartTracker(addr: InetSocketAddress,
+ maxPortTrials: Int,
+ connectionTimeout: Duration) extends TrackerControlMessage
+ // To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
+ // driver for the distributed training.
+ case class InterruptTracker(e: Throwable) extends TrackerControlMessage
+
+ def props(numWorkers: Int): Props =
+ Props(new RabitTrackerHandler(numWorkers))
+}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala
new file mode 100644
index 000000000..31acfc1ce
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala
@@ -0,0 +1,456 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.rabit.handler
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import akka.io.Tcp
+import akka.actor._
+import akka.util.ByteString
+import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}
+
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration._
+import scala.util.Try
+
+/**
+ * Actor to handle socket communication from worker node.
+ * To handle fragmentation in received data, this class acts like a FSM
+ * (finite-state machine) to keep track of the internal states.
+ *
+ * @param host IP address of the remote worker
+ * @param worldSize number of total workers
+ * @param tracker the RabitTrackerHandler actor reference
+ */
+private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
+ connection: ActorRef)
+ extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
+ with ActorLogging with Stash {
+
+ import RabitWorkerHandler._
+ import RabitTrackerHelpers._
+
+ private[this] var rank: Int = 0
+ private[this] var port: Int = 0
+
+ // indicate if the connection is transient (like "print" or "shutdown")
+ private[this] var transient: Boolean = false
+ private[this] var peerClosed: Boolean = false
+
+ // number of workers pending acceptance of current worker
+ private[this] var awaitingAcceptance: Int = 0
+ private[this] var neighboringWorkers = Set.empty[Int]
+
+ // TODO: use a single memory allocation to host all buffers,
+ // including the transient ones for writing.
+ private[this] val readBuffer = ByteBuffer.allocate(4096)
+ .order(ByteOrder.nativeOrder())
+ // in case the received message is longer than needed,
+ // stash the spilled over part in this buffer, and send
+ // to self when transition occurs.
+ private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
+ .order(ByteOrder.nativeOrder())
+ // when setup is complete, need to notify peer handlers
+ // to reduce the awaiting-connection counter.
+ private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None
+
+ private def resetBuffers(): Unit = {
+ readBuffer.clear()
+ if (spillOverBuffer.position() > 0) {
+ spillOverBuffer.flip()
+ self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
+ spillOverBuffer.clear()
+ }
+ }
+
+ private def stashSpillOver(buf: ByteBuffer): Unit = {
+ if (buf.remaining() > 0) spillOverBuffer.put(buf)
+ }
+
+ def getNeighboringWorkers: Set[Int] = neighboringWorkers
+
+ def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
+ val rank = buffer.getInt()
+ val worldSize = buffer.getInt()
+ val jobId = buffer.getString
+
+ val command = buffer.getString
+ command match {
+ case "start" => WorkerStart(rank, worldSize, jobId)
+ case "shutdown" =>
+ transient = true
+ WorkerShutdown(rank, worldSize, jobId)
+ case "recover" =>
+ require(rank >= 0, "Invalid rank for recovering worker.")
+ WorkerRecover(rank, worldSize, jobId)
+ case "print" =>
+ transient = true
+ WorkerTrackerPrint(rank, worldSize, jobId, buffer.getString)
+ }
+ }
+
+ startWith(AwaitingHandshake, DataStruct())
+
+ when(AwaitingHandshake) {
+ case Event(Tcp.Received(magic), _) =>
+ assert(magic.length == 4)
+ val purportedMagic = magic.asNativeOrderByteBuffer.getInt
+ assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")
+
+ // echo back the magic number
+ connection ! Tcp.Write(magic)
+ goto(AwaitingCommand) using StructTrackerCommand
+ }
+
+ when(AwaitingCommand) {
+ case Event(Tcp.Received(bytes), validator) =>
+ bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
+ if (validator.verify(readBuffer)) {
+ readBuffer.flip()
+ tracker ! decodeCommand(readBuffer)
+ stashSpillOver(readBuffer)
+ }
+
+ stay
+ // when rank for a worker is assigned, send encoded rank information
+ // back to worker over Tcp socket.
+ case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
+ log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")
+
+ rank = assignedRank
+ // ranks from the ring
+ val ringRanks = List(
+ // ringPrev
+ if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
+ // ringNext
+ if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
+ )
+
+ // update the set of all linked workers to current worker.
+ neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet
+
+ connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
+ // to prevent reading before state transition
+ connection ! Tcp.SuspendReading
+ goto(BuildingLinkMap) using StructNodes
+ }
+
+ when(BuildingLinkMap) {
+ case Event(Tcp.Received(bytes), validator) =>
+ bytes.asByteBuffers.foreach { buf =>
+ readBuffer.put(buf)
+ }
+
+ if (validator.verify(readBuffer)) {
+ readBuffer.flip()
+ // for a freshly started worker, numConnected should be 0.
+ val numConnected = readBuffer.getInt()
+ val toConnectSet = neighboringWorkers.diff(
+ (0 until numConnected).map { index => readBuffer.getInt() }.toSet)
+
+ // check which workers are currently awaiting connections
+ tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
+ }
+ stay
+
+ // got a Future from the tracker (resolver) about workers that are
+ // currently awaiting connections (particularly from this node.)
+ case Event(future: Future[_], _) =>
+ // blocks execution until all dependencies for current worker is resolved.
+ Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
+ // numNotReachable is the number of workers that currently
+ // cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
+ // connections from those currently non-reachable nodes in the future.
+ case AwaitingConnections(waitConnNodes, numNotReachable) =>
+ log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")
+ val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
+ buf.putInt(waitConnNodes.size).putInt(numNotReachable)
+ buf.flip()
+
+ // cache this message until the final state (SetupComplete)
+ pendingAcknowledgement = Some(AcknowledgeAcceptance(
+ waitConnNodes, numNotReachable))
+
+ connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
+ if (waitConnNodes.isEmpty) {
+ connection ! Tcp.SuspendReading
+ goto(AwaitingErrorCount)
+ }
+ else {
+ waitConnNodes.foreach { case (peerRank, peerRef) =>
+ peerRef ! RequestWorkerHostPort
+ }
+
+ // a countdown for DivulgedHostPort messages.
+ stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
+ }
+ }
+
+ case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
+ val hostBytes = peerHost.getBytes()
+ val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
+ .order(ByteOrder.nativeOrder())
+ buffer.putInt(peerHost.length).put(hostBytes)
+ .putInt(peerPort).putInt(peerRank)
+
+ buffer.flip()
+ connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))
+
+ if (data.counter == 0) {
+ // to prevent reading before state transition
+ connection ! Tcp.SuspendReading
+ goto(AwaitingErrorCount)
+ }
+ else {
+ stay using data.decrement()
+ }
+ }
+
+ when(AwaitingErrorCount) {
+ case Event(Tcp.Received(numErrors), _) =>
+ val buf = numErrors.asNativeOrderByteBuffer
+
+ buf.getInt match {
+ case 0 =>
+ stashSpillOver(buf)
+ goto(AwaitingPortNumber)
+ case _ =>
+ stashSpillOver(buf)
+ goto(BuildingLinkMap) using StructNodes
+ }
+ }
+
+ when(AwaitingPortNumber) {
+ case Event(Tcp.Received(assignedPort), _) =>
+ assert(assignedPort.length == 4)
+ port = assignedPort.asNativeOrderByteBuffer.getInt
+ log.debug(s"Rank $rank listening @ $host:$port")
+ // wait until the worker closes connection.
+ if (peerClosed) goto(SetupComplete) else stay
+
+ case Event(Tcp.PeerClosed, _) =>
+ peerClosed = true
+ if (port == 0) stay else goto(SetupComplete)
+ }
+
+ when(SetupComplete) {
+ case Event(ReduceWaitCount(count: Int), _) =>
+ awaitingAcceptance -= count
+ // check peerClosed to avoid prematurely stopping this actor (which sends RST to worker)
+ if (awaitingAcceptance == 0 && peerClosed) {
+ tracker ! DropFromWaitingList(rank)
+ // no longer needed.
+ context.stop(self)
+ }
+ stay
+
+ case Event(AcknowledgeAcceptance(peers, numBad), _) =>
+ awaitingAcceptance = numBad
+ tracker ! WorkerStarted(host, rank, awaitingAcceptance)
+ peers.values.foreach { peer =>
+ peer ! ReduceWaitCount(1)
+ }
+
+ if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill
+
+ stay
+
+ // can only divulge the complete host and port information
+ // when this worker is declared fully connected (otherwise
+ // port information is still missing.)
+ case Event(RequestWorkerHostPort, _) =>
+ sender() ! DivulgedWorkerHostPort(rank, host, port)
+ stay
+ }
+
+ onTransition {
+ // reset buffer when state transitions as data becomes stale
+ case _ -> SetupComplete =>
+ connection ! Tcp.ResumeReading
+ resetBuffers()
+ if (pendingAcknowledgement.isDefined) {
+ self ! pendingAcknowledgement.get
+ }
+ case _ =>
+ connection ! Tcp.ResumeReading
+ resetBuffers()
+ }
+
+ // default message handler
+ whenUnhandled {
+ case Event(Tcp.PeerClosed, _) =>
+ peerClosed = true
+ if (transient) context.stop(self)
+ stay
+ }
+}
+
+private[scala] object RabitWorkerHandler {
+ val MAGIC_NUMBER = 0xff99
+
+ // Finite states of this actor, which acts like a FSM.
+ // The following states are defined in order as the FSM progresses.
+ sealed trait State
+
+ // [1] Initial state, awaiting worker to send magic number per protocol.
+ case object AwaitingHandshake extends State
+ // [2] Awaiting worker to send command (start/print/recover/shutdown etc.)
+ case object AwaitingCommand extends State
+ // [3] Brokers connections between workers per ring/tree/parent link map.
+ case object BuildingLinkMap extends State
+ // [4] A transient state in which the worker reports the number of errors in establishing
+ // connections to other peer workers. If no errors, transition to next state.
+ case object AwaitingErrorCount extends State
+ // [5] Awaiting the worker to report its port number for accepting connections from peer workers.
+ // This port number information is later forwarded to linked workers.
+ case object AwaitingPortNumber extends State
+ // [6] Final state after completing the setup with the connecting worker. At this stage, the
+ // worker will have closed the Tcp connection. The actor remains alive to handle messages from
+ // peer actors representing workers with pending setups.
+ case object SetupComplete extends State
+
+ sealed trait DataField
+ case object IntField extends DataField
+  // a string encoded as an integer length prefix followed by the string bytes
+ case object StringField extends DataField
+ case object IntSeqField extends DataField
+
+ object DataStruct {
+ def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
+ }
+
+ // Internal data pertaining to individual state, used to verify the validity of packets sent by
+ // workers.
+ case class DataStruct(fields: Seq[DataField], counter: Int) {
+ /**
+ * Validate whether the provided buffer is complete (i.e., contains
+ * all data fields specified for this DataStruct.)
+ *
+ * @param buf a byte buffer containing received data.
+ */
+ def verify(buf: ByteBuffer): Boolean = {
+ if (fields.isEmpty) return true
+
+ val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
+ dupBuf.flip()
+
+ Try(fields.foldLeft(true) {
+ case (complete, field) =>
+ val remBytes = dupBuf.remaining()
+ complete && (remBytes > 0) && (remBytes >= (field match {
+ case IntField =>
+ dupBuf.position(dupBuf.position() + 4)
+ 4
+ case StringField =>
+ val strLen = dupBuf.getInt
+ dupBuf.position(dupBuf.position() + strLen)
+ 4 + strLen
+ case IntSeqField =>
+ val seqLen = dupBuf.getInt
+ dupBuf.position(dupBuf.position() + seqLen * 4)
+ 4 + seqLen * 4
+ }))
+ }).getOrElse(false)
+ }
+
+ def increment(): DataStruct = DataStruct(fields, counter + 1)
+ def decrement(): DataStruct = DataStruct(fields, counter - 1)
+ }
+
+ val StructNodes = DataStruct(List(IntSeqField), 0)
+ val StructTrackerCommand = DataStruct(List(
+ IntField, IntField, StringField, StringField
+ ), 0)
+
+ // ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ----
+
+ // RabitWorkerHandler --> RabitTrackerHandler
+ sealed trait RabitWorkerRequest
+ // RabitWorkerHandler <-- RabitTrackerHandler
+ sealed trait RabitWorkerResponse
+
+ // Representations of decoded worker commands.
+ abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
+ def rank: Int
+ def worldSize: Int
+ def jobId: String
+
+ def encode: ByteString = {
+ val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length)
+ .order(ByteOrder.nativeOrder())
+
+ buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
+ .putInt(command.length).put(command.getBytes()).flip()
+
+ ByteString.fromByteBuffer(buf)
+ }
+ }
+
+ case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
+ extends TrackerCommand("start")
+ case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
+ extends TrackerCommand("shutdown")
+ case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
+ extends TrackerCommand("recover")
+ case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
+ extends TrackerCommand("print") {
+
+ override def encode: ByteString = {
+ val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length)
+ .order(ByteOrder.nativeOrder())
+
+ buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
+ .putInt(command.length).put(command.getBytes())
+ .putInt(msg.length).put(msg.getBytes()).flip()
+
+ ByteString.fromByteBuffer(buf)
+ }
+ }
+
+ // Request to remove the worker of given rank from the list of workers awaiting peer connections.
+ case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
+ // Notify the tracker that the worker of given rank has finished setup and started.
+ case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
+ extends RabitWorkerRequest
+ // Request the set of workers to connect to, according to the LinkMap structure.
+ case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
+ extends RabitWorkerRequest
+
+ // Request, from the tracker, the set of nodes to connect.
+ case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
+ extends RabitWorkerResponse
+
+ // ---- Messages between ConnectionHandler actors ----
+ sealed trait IntraWorkerMessage
+
+ // Notify neighboring workers to decrease the counter of awaiting workers by `count`.
+ case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
+  // Request host and port information from peer ConnectionHandler actors (acting on behalf of
+ // connecting workers.) This message will be brokered by RabitTrackerHandler.
+ case object RequestWorkerHostPort extends IntraWorkerMessage
+ // Response to the above request
+ case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
+ // A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
+ case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
+ extends IntraWorkerMessage
+
+ // ---- End of message definitions ----
+
+ def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
+ Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
+ }
+}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala
new file mode 100644
index 000000000..edec4931b
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala
@@ -0,0 +1,136 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.rabit.util
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+/**
+ * The assigned rank to a connecting Rabit worker, along with the information of the ranks of
+ * its linked peer workers, which are critical to perform Allreduce.
+ * When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
+ * client, RabitTrackerHandler utilizes LinkMap to figure out linkage relationships, and respond
+ * with this class as a message, which is later encoded to byte string, and sent over socket
+ * connection to the worker client.
+ *
+ * @param rank assigned rank (ranked by worker connection order: first worker connecting to the
+ * tracker is assigned rank 0, second with rank 1, etc.)
+ * @param neighbors ranks of neighboring workers in a tree map.
+ * @param ring ranks of neighboring workers in a ring map.
+ * @param parent rank of the parent worker.
+ */
+private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
+ ring: (Int, Int), parent: Int) {
+ /**
+ * Encode the AssignedRank message into byte sequence for socket communication with Rabit worker
+ * client.
+ * @param worldSize the number of total distributed workers. Must match `numWorkers` used in
+ * LinkMap.
+ * @return a ByteBuffer containing encoded data.
+ */
+ def toByteBuffer(worldSize: Int): ByteBuffer = {
+ val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder())
+ buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length)
+ // neighbors in tree structure
+ neighbors.foreach { n => buffer.putInt(n) }
+ buffer.putInt(if (ring._1 != -1 && ring._1 != rank) ring._1 else -1)
+ buffer.putInt(if (ring._2 != -1 && ring._2 != rank) ring._2 else -1)
+
+ buffer.flip()
+ buffer
+ }
+}
+
+private[rabit] class LinkMap(numWorkers: Int) {
+ private def getNeighbors(rank: Int): Seq[Int] = {
+ val rank1 = rank + 1
+ Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
+ r >= 0 && r < numWorkers
+ }
+ }
+
+ /**
+ * Construct a ring structure that tends to share nodes with the tree.
+ *
+   * @param treeMap map from each worker rank to the ranks of its tree neighbors.
+   * @param parentMap map from each worker rank to the rank of its parent (-1 for the root).
+   * @param rank the rank at which to start constructing the ring (defaults to the root, 0).
+ * @return Seq[Int] instance starting from rank.
+ */
+ private def constructShareRing(treeMap: Map[Int, Seq[Int]],
+ parentMap: Map[Int, Int],
+ rank: Int = 0): Seq[Int] = {
+ treeMap(rank).toSet - parentMap(rank) match {
+ case emptySet if emptySet.isEmpty =>
+ List(rank)
+ case connectionSet =>
+ connectionSet.zipWithIndex.foldLeft(List(rank)) {
+ case (ringSeq, (v, cnt)) =>
+ val vConnSeq = constructShareRing(treeMap, parentMap, v)
+ vConnSeq match {
+ case vconn if vconn.size == cnt + 1 =>
+ ringSeq ++ vconn.reverse
+ case vconn =>
+ ringSeq ++ vconn
+ }
+ }
+ }
+ }
+ /**
+ * Construct a ring connection used to recover local data.
+ *
+   * @param treeMap map from each worker rank to the ranks of its tree neighbors.
+   * @param parentMap map from each worker rank to the rank of its parent (-1 for the root).
+ */
+ private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = {
+ assert(parentMap(0) == -1)
+
+ val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
+ assert(sharedRing.length == treeMap.size)
+
+ (0 until numWorkers).map { r =>
+ val rPrev = (r + numWorkers - 1) % numWorkers
+ val rNext = (r + 1) % numWorkers
+ sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
+ }.toMap
+ }
+
+ private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
+ private[this] val parentMap_ = (0 until numWorkers).map{ r => r -> ((r + 1) / 2 - 1) }.toMap
+ private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)
+ val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
+ case ((rmap, k), i) =>
+ val kNext = ringMap_(k)._2
+ (rmap ++ Map(kNext -> (i + 1)), kNext)
+ }._1
+
+ val ringMap = ringMap_.map {
+ case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
+ }
+ val treeMap = treeMap_.map {
+ case (k, vSeq) => rMap_(k) -> vSeq.map{ v => rMap_(v) }
+ }
+ val parentMap = parentMap_.map {
+ case (k, v) if k == 0 =>
+ rMap_(k) -> -1
+ case (k, v) =>
+ rMap_(k) -> rMap_(v)
+ }
+
+ def assignRank(rank: Int): AssignedRank = {
+ AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
+ }
+}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala
new file mode 100644
index 000000000..3d7be618d
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala
@@ -0,0 +1,39 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.rabit.util
+
+import java.nio.{ByteOrder, ByteBuffer}
+import akka.util.ByteString
+
+private[rabit] object RabitTrackerHelpers {
+ implicit class ByteStringHelplers(bs: ByteString) {
+ // Java by default uses big endian. Enforce native endian so that
+ // the byte order is consistent with the workers.
+ def asNativeOrderByteBuffer: ByteBuffer = {
+ bs.asByteBuffer.order(ByteOrder.nativeOrder())
+ }
+ }
+
+ implicit class ByteBufferHelpers(buf: ByteBuffer) {
+ def getString: String = {
+ val len = buf.getInt()
+ val stringBuffer = ByteBuffer.allocate(len).order(ByteOrder.nativeOrder())
+ buf.get(stringBuffer.array(), 0, len)
+ new String(stringBuffer.array(), "utf-8")
+ }
+ }
+}
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
index db4f93b44..0c4a85dcc 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -94,6 +94,7 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
long max_elem = cbatch.offset[cbatch.size];
cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0);
cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);
+
CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
<< "batch.index.length must equal batch.offset.back()";
CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem)
@@ -756,3 +757,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}
+
+/*
+ * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method: RabitAllreduce
+ * Signature: (Ljava/nio/ByteBuffer;III)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
+ (JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
+ void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
+ RabitAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op, NULL, NULL);
+
+ return 0;
+}
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h
index 15410abed..8e42eea1c 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.h
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h
@@ -303,6 +303,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
(JNIEnv *, jclass, jintArray);
+/*
+ * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method: RabitAllreduce
+ * Signature: (Ljava/nio/ByteBuffer;III)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
+ (JNIEnv *, jclass, jobject, jint, jint, jint);
+
#ifdef __cplusplus
}
#endif
diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala
new file mode 100644
index 000000000..ee4febe39
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala
@@ -0,0 +1,224 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.rabit
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import akka.actor.{ActorRef, ActorSystem}
+import akka.io.Tcp
+import akka.testkit.{ImplicitSender, TestFSMRef, TestKit, TestProbe}
+import akka.util.ByteString
+import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler
+import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler._
+import ml.dmlc.xgboost4j.scala.rabit.util.LinkMap
+import org.junit.runner.RunWith
+import org.scalatest.junit.JUnitRunner
+import org.scalatest.{FlatSpecLike, Matchers}
+
+import scala.concurrent.Promise
+
+object RabitTrackerConnectionHandlerTest {
+ def intSeqToByteString(seq: Seq[Int]): ByteString = {
+ val buf = ByteBuffer.allocate(seq.length * 4).order(ByteOrder.nativeOrder())
+ seq.foreach { i => buf.putInt(i) }
+ buf.flip()
+ ByteString.fromByteBuffer(buf)
+ }
+}
+
+@RunWith(classOf[JUnitRunner])
+class RabitTrackerConnectionHandlerTest
+ extends TestKit(ActorSystem("RabitTrackerConnectionHandlerTest"))
+ with FlatSpecLike with Matchers with ImplicitSender {
+
+ import RabitTrackerConnectionHandlerTest._
+
+ val magic = intSeqToByteString(List(0xff99))
+
+ "RabitTrackerConnectionHandler" should "handle Rabit client 'start' command properly" in {
+ val trackerProbe = TestProbe()
+ val connProbe = TestProbe()
+
+ val worldSize = 4
+
+ val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
+ trackerProbe.ref, connProbe.ref))
+ fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
+
+ // send mock magic number
+ fsm ! Tcp.Received(magic)
+ connProbe.expectMsg(Tcp.Write(magic))
+
+ fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
+ fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
+ // ResumeReading should be seen once state transitions
+ connProbe.expectMsg(Tcp.ResumeReading)
+
+ // send mock tracker command in fragments: the handler should be able to handle it.
+ val bufRank = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
+ bufRank.putInt(0).putInt(worldSize).flip()
+
+ val bufJobId = ByteBuffer.allocate(5).order(ByteOrder.nativeOrder())
+ bufJobId.putInt(1).put(Array[Byte]('0')).flip()
+
+ val bufCmd = ByteBuffer.allocate(9).order(ByteOrder.nativeOrder())
+ bufCmd.putInt(5).put("start".getBytes()).flip()
+
+ fsm ! Tcp.Received(ByteString.fromByteBuffer(bufRank))
+ fsm ! Tcp.Received(ByteString.fromByteBuffer(bufJobId))
+
+ // the state should not change for incomplete command data.
+ fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
+
+ // send the last fragment, and expect message at tracker actor.
+ fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
+ trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
+
+ val linkMap = new LinkMap(worldSize)
+ val assignedRank = linkMap.assignRank(0)
+ trackerProbe.reply(assignedRank)
+
+ connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
+ assignedRank.toByteBuffer(worldSize)
+ )))
+
+ // reading should be suspended upon transitioning to BuildingLinkMap
+ connProbe.expectMsg(Tcp.SuspendReading)
+ // state should transition with according state data changes.
+ fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
+ fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
+ connProbe.expectMsg(Tcp.ResumeReading)
+
+ // since the connection handler in test has rank 0, it will not have any nodes to connect to.
+ fsm ! Tcp.Received(intSeqToByteString(List(0)))
+ trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
+
+ // return mock response to the connection handler
+ val awaitConnPromise = Promise[AwaitingConnections]()
+ awaitConnPromise.success(AwaitingConnections(Map.empty[Int, ActorRef],
+ fsm.underlyingActor.getNeighboringWorkers.size
+ ))
+ fsm ! awaitConnPromise.future
+ connProbe.expectMsg(Tcp.Write(
+ intSeqToByteString(List(0, fsm.underlyingActor.getNeighboringWorkers.size))
+ ))
+ connProbe.expectMsg(Tcp.SuspendReading)
+ fsm.stateName shouldEqual RabitWorkerHandler.AwaitingErrorCount
+ connProbe.expectMsg(Tcp.ResumeReading)
+
+ // send mock error count (0)
+ fsm ! Tcp.Received(intSeqToByteString(List(0)))
+
+ fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
+ connProbe.expectMsg(Tcp.ResumeReading)
+
+ // simulate Tcp.PeerClosed event first, then Tcp.Received to test handling of async events.
+ fsm ! Tcp.PeerClosed
+ // state should not transition
+ fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
+ fsm ! Tcp.Received(intSeqToByteString(List(32768)))
+
+ fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
+ connProbe.expectMsg(Tcp.ResumeReading)
+
+ trackerProbe.expectMsg(RabitWorkerHandler.WorkerStarted("localhost", 0, 2))
+
+ val handlerStopProbe = TestProbe()
+ handlerStopProbe watch fsm
+
+ // simulate connections from other workers by mocking ReduceWaitCount commands
+ fsm ! RabitWorkerHandler.ReduceWaitCount(1)
+ fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
+ fsm ! RabitWorkerHandler.ReduceWaitCount(1)
+ trackerProbe.expectMsg(RabitWorkerHandler.DropFromWaitingList(0))
+ handlerStopProbe.expectTerminated(fsm)
+
+ // all done.
+ }
+
+ it should "forward print command to tracker" in {
+ val trackerProbe = TestProbe()
+ val connProbe = TestProbe()
+
+ val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
+ trackerProbe.ref, connProbe.ref))
+ fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
+
+ fsm ! Tcp.Received(magic)
+ connProbe.expectMsg(Tcp.Write(magic))
+
+ fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
+ fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
+ // ResumeReading should be seen once state transitions
+ connProbe.expectMsg(Tcp.ResumeReading)
+
+ val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!")
+ fsm ! Tcp.Received(printCmd.encode)
+
+ trackerProbe.expectMsg(printCmd)
+ }
+
+ it should "handle spill-over Tcp data correctly between state transition" in {
+ val trackerProbe = TestProbe()
+ val connProbe = TestProbe()
+
+ val worldSize = 4
+
+ val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
+ trackerProbe.ref, connProbe.ref))
+ fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
+
+ // send mock magic number
+ fsm ! Tcp.Received(magic)
+ connProbe.expectMsg(Tcp.Write(magic))
+
+ fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
+ fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
+ // ResumeReading should be seen once state transitions
+ connProbe.expectMsg(Tcp.ResumeReading)
+
+ // send mock tracker command in fragments: the handler should be able to handle it.
+ val bufCmd = ByteBuffer.allocate(26).order(ByteOrder.nativeOrder())
+ bufCmd.putInt(0).putInt(worldSize).putInt(1).put(Array[Byte]('0'))
+ .putInt(5).put("start".getBytes())
+ // spilled-over data
+ .putInt(0).flip()
+
+ // send data with 4 extra bytes corresponding to the next state.
+ fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
+
+ trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
+
+ val linkMap = new LinkMap(worldSize)
+ val assignedRank = linkMap.assignRank(0)
+ trackerProbe.reply(assignedRank)
+
+ connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
+ assignedRank.toByteBuffer(worldSize)
+ )))
+
+ // reading should be suspended upon transitioning to BuildingLinkMap
+ connProbe.expectMsg(Tcp.SuspendReading)
+ // state should transition with according state data changes.
+ fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
+ fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
+ connProbe.expectMsg(Tcp.ResumeReading)
+
+ // the handler should be able to handle spill-over data, and stash it until state transition.
+ trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
+ }
+}