[jvm-packages] Scala implementation of the Rabit tracker. (#1612)

* [jvm-packages] Scala implementation of the Rabit tracker.

A Scala implementation of RabitTracker that is interface-interchangable with the
Java implementation, ported from `tracker.py` in the
[dmlc-core project](https://github.com/dmlc/dmlc-core).

* [jvm-packages] Updated Akka dependency in pom.xml.

* Refactored the RabitTracker directory structure.

* Fixed premature stopping of connection handler.

Added a new finite state "AwaitingPortNumber" to explicitly wait for the
worker to send the port, and close the connection. Stopping the actor
prematurely sends a TCP RST to the worker, causing the worker to crash
on AssertionError.

* Added interface IRabitTracker so that user can switch implementations.

* Default timeout duration changes.

* Dependency for Akka tests.

* Removed the main function of RabitTracker.

* A skeleton for testing Akka-based Rabit tracker.

* waitFor() in RabitTracker no longer throws exceptions.

* Completed unit test for the 'start' command of Rabit tracker.

* Preliminary support for Rabit Allreduce via JNI (no prepare function support yet.)

* Fixed the default timeout duration.

* Use Java container to avoid serialization issues due to intermediate wrappers.

* Added tests for Allreduce/model training using Scala Rabit tracker.

* Added spill-over unit test for the Scala Rabit tracker.

* Fixed a typo.

* Overhaul of RabitTracker interface per code review.

  - Removed methods start() waitFor() (no arguments) from IRabitTracker.
  - The timeout in start(timeout) is now worker connection timeout, as tcp
    socket binding timeout is less intuitive.
  - Dropped time unit from start(...) and waitFor(...) methods; the default
    time unit is millisecond.
  - Moved random port number generation into the RabitTrackerHandler.
  - Moved all Rabit-related classes to package ml.dmlc.xgboost4j.scala.rabit.

* More code refactoring and comments.

* Unified timeout constants. Readable tracker status code.

* Add comments to indicate that allReduce is for tests only. Removed all other variants.

* Removed unused imports.

* Simplified signatures of training methods.

 - Moved TrackerConf into parameter map.
 - Changed GeneralParams so that TrackerConf becomes a standalone parameter.
 - Updated test cases accordingly.

* Changed monitoring strategies.

* Reverted monitoring changes.

* Update test case for Rabit AllReduce.

* Mix in UncaughtExceptionHandler into IRabitTracker to prevent tracker from hanging due to exceptions thrown by workers.

* More comprehensive test cases for exception handling and worker connection timeout.

* Handle executor loss due to unknown cause: the newly spawned executor will attempt to connect to the tracker. Interrupt tracker in such case.

* Per code-review, removed training timeout from TrackerConf. Timeout logic must be implemented explicitly and externally in the driver code.

* Reverted scalastyle-config changes.

* Visibility scope change. Interface tweaks.

* Use match pattern to handle tracker_conf parameter.

* Minor clarification in JNI code.

* Clearer intent in match pattern to suppress warnings.

* Removed Future from constructor. Block in start() and waitFor() instead.

* Revert inadvertent comment changes.

* Removed debugging information.

* Updated test cases that are a bit finicky.

* Added comments on the reasoning behind the unit tests for testing Rabit tracker robustness.
This commit is contained in:
Xin Yin 2016-12-07 09:35:42 -05:00 committed by Nan Zhu
parent 7078c41dad
commit e7fbc8591f
19 changed files with 1910 additions and 25 deletions

View File

@ -57,7 +57,7 @@ object CustomObjective {
case e: XGBoostError => case e: XGBoostError =>
logger.error(e) logger.error(e)
null null
case _ => case _: Throwable =>
null null
} }
val grad = new Array[Float](nrow) val grad = new Array[Float](nrow)

View File

@ -85,7 +85,7 @@ object XGBoost {
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int): def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
XGBoostModel = { XGBoostModel = {
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism) val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
if (tracker.start()) { if (tracker.start(0L)) {
dtrain dtrain
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs)) .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
.reduce((x, y) => x).collect().head .reduce((x, y) => x).collect().head

View File

@ -16,11 +16,10 @@
package ml.dmlc.xgboost4j.scala.spark package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
import scala.collection.mutable import scala.collection.mutable
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError} import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.commons.logging.LogFactory import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FSDataInputStream, Path} import org.apache.hadoop.fs.{FSDataInputStream, Path}
@ -30,6 +29,25 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset import org.apache.spark.sql.Dataset
import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.{SparkContext, TaskContext}
import scala.concurrent.duration.{Duration, MILLISECONDS}
object TrackerConf {
def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python")
}
/**
* Rabit tracker configurations.
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
* Set timeout length to zero to disable timeout.
* Use a finite, non-zero timeout value to prevent tracker from
* hanging indefinitely (supported by "scala" implementation only.)
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*/
case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
object XGBoost extends Serializable { object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark") private val logger = LogFactory.getLog("XGBoostSpark")
@ -80,7 +98,7 @@ object XGBoost extends Serializable {
private[spark] def buildDistributedBoosters( private[spark] def buildDistributedBoosters(
trainingSet: RDD[MLLabeledPoint], trainingSet: RDD[MLLabeledPoint],
xgBoostConfMap: Map[String, Any], xgBoostConfMap: Map[String, Any],
rabitEnv: mutable.Map[String, String], rabitEnv: java.util.Map[String, String],
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait, numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = { useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
import DataUtils._ import DataUtils._
@ -92,7 +110,7 @@ object XGBoost extends Serializable {
partitionedTrainingSet.mapPartitions { partitionedTrainingSet.mapPartitions {
trainingSamples => trainingSamples =>
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava) Rabit.init(rabitEnv)
var booster: Booster = null var booster: Booster = null
if (trainingSamples.hasNext) { if (trainingSamples.hasNext) {
val cacheFileName: String = { val cacheFileName: String = {
@ -211,9 +229,21 @@ object XGBoost extends Serializable {
overridedParams overridedParams
} }
private def startTracker(nWorkers: Int): RabitTracker = { private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker = new RabitTracker(nWorkers) val tracker: IRabitTracker = trackerConf.trackerImpl match {
require(tracker.start(), "FAULT: Failed to start tracker") case "scala" => new RabitTracker(nWorkers)
case "python" => new PyRabitTracker(nWorkers)
case _ => new PyRabitTracker(nWorkers)
}
val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) {
trackerConf.workerConnectionTimeout.toMillis
} else {
// 0 == Duration.Inf
0L
}
require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker")
tracker tracker
} }
@ -243,19 +273,26 @@ object XGBoost extends Serializable {
" you have to specify the objective type as classification or regression with a" + " you have to specify the objective type as classification or regression with a" +
" customized objective function") " customized objective function")
} }
val tracker = startTracker(nWorkers) val trackerConf = params.get("tracker_conf") match {
case None => TrackerConf()
case Some(conf: TrackerConf) => conf
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
"instance of TrackerConf.")
}
val tracker = startTracker(nWorkers, trackerConf)
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext) val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
val boosters = buildDistributedBoosters(trainingData, overridedConfMap, val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing) tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
val sparkJobThread = new Thread() { val sparkJobThread = new Thread() {
override def run() { override def run() {
// force the job // force the job
boosters.foreachPartition(() => _) boosters.foreachPartition(() => _)
} }
} }
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start() sparkJobThread.start()
val isClsTask = isClassificationTask(params) val isClsTask = isClassificationTask(params)
val trackerReturnVal = tracker.waitFor() val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal") logger.info(s"Rabit returns with exit code $trackerReturnVal")
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread, postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
isClsTask) isClsTask)

View File

@ -16,9 +16,12 @@
package ml.dmlc.xgboost4j.scala.spark.params package ml.dmlc.xgboost4j.scala.spark.params
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.spark.ml.param._ import org.apache.spark.ml.param._
import scala.concurrent.duration.{Duration, NANOSECONDS}
trait GeneralParams extends Params { trait GeneralParams extends Params {
/** /**
@ -69,7 +72,38 @@ trait GeneralParams extends Params {
*/ */
val missing = new FloatParam(this, "missing", "the value treated as missing") val missing = new FloatParam(this, "missing", "the value treated as missing")
/**
* Rabit tracker configurations. The parameter must be provided as an instance of the
* TrackerConf class, which has the following definition:
*
* case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration,
* trackerImpl: String)
*
* See below for detailed explanations.
*
* - trackerImpl: Select the implementation of Rabit tracker.
* default: "python"
*
* Choice between "python" or "scala". The former utilizes the Java wrapper of the
* Python Rabit tracker (in dmlc_core), and does not support timeout settings.
* The "scala" version removes Python components, and fully supports timeout settings.
*
* - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
* default: 0 millisecond (no timeout)
*
* The timeout value should take the time of data loading and pre-processing into account,
* due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
* perform data transformation before calling XGBoost.train(), so that this timeout truly
* reflects the connection delay. Set a reasonable timeout value to prevent model
* training/testing from hanging indefinitely, possible due to network issues.
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
* Ignored if the tracker implementation is "python".
*/
val trackerConf = new Param[TrackerConf](this, "tracker_conf", "Rabit tracker configurations")
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1, setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
useExternalMemory -> false, silent -> 0, useExternalMemory -> false, silent -> 0,
customObj -> null, customEval -> null, missing -> Float.NaN) customObj -> null, customEval -> null, missing -> Float.NaN,
trackerConf -> TrackerConf()
)
} }

View File

@ -0,0 +1,169 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.FunSuite
class RabitTrackerRobustnessSuite extends FunSuite with Utils {
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
that should be avoided.
*/
val sparkConf = new SparkConf().setMaster("local[*]")
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
implicit val sparkContext = new SparkContext(sparkConf)
sparkContext.setLogLevel("ERROR")
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new PyRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val workerCount: Int = numWorkers
/*
Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
last created worker. A cascading event chain will be triggered once the RuntimeException is
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
The Java RabitTracker class reacts to exceptions by killing the spawned process running
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
it is killed, the resulted connection failure will trigger the Rabit worker to call
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
running in separate containers. However, as unit tests are run in Spark local mode, in which
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
ultimately kills the entire process, which also happens to host the Spark driver, causing
the entire Spark application to crash.
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
the exception is thrown at last, ideally after all worker connections have been established.
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
process to ensure that pending worker connections are handled.
*/
val dummyTasks = rdd.mapPartitions { iter =>
Rabit.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Rabit.shutdown()
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0) != 0)
sparkContext.stop()
}
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
val sparkConf = new SparkConf().setMaster("local[*]")
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
implicit val sparkContext = new SparkContext(sparkConf)
sparkContext.setLogLevel("ERROR")
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val workerCount: Int = numWorkers
val dummyTasks = rdd.mapPartitions { iter =>
Rabit.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Rabit.shutdown()
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
sparkContext.stop()
}
test("test Scala RabitTracker's workerConnectionTimeout") {
val sparkConf = new SparkConf().setMaster("local[*]")
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
implicit val sparkContext = new SparkContext(sparkConf)
sparkContext.setLogLevel("ERROR")
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(500)
val trackerEnvs = tracker.getWorkerEnvs
val dummyTasks = rdd.mapPartitions { iter =>
val index = iter.next()
// simulate that the first worker cannot connect to tracker due to network issues.
if (index != 1) {
Rabit.init(trackerEnvs)
Thread.sleep(1000)
Rabit.shutdown()
}
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
// should fail due to connection timeout
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
sparkContext.stop()
}
}

View File

@ -17,18 +17,60 @@
package ml.dmlc.xgboost4j.scala.spark package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files import java.nio.file.Files
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import scala.util.Random import scala.util.Random
import scala.concurrent.duration._
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import org.apache.spark.SparkContext import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector => SparkVector, Vectors} import org.apache.spark.ml.linalg.{Vectors, Vector => SparkVector}
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
class XGBoostGeneralSuite extends SharedSparkContext with Utils { class XGBoostGeneralSuite extends SharedSparkContext with Utils {
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
val tracker = new RabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
val rawData = rdd.mapPartitions { iter =>
Iterator(iter.toArray)
}.collect()
val maxVec = (0 until vectorLength).toArray.map { j =>
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
}
val allReduceResults = rdd.mapPartitions { iter =>
Rabit.init(trackerEnvs)
val arr = iter.toArray
val results = Rabit.allReduce(arr, Rabit.OpType.MAX)
Rabit.shutdown()
Iterator(results)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
allReduceResults.foreachPartition(() => _)
val byPartitionResults = allReduceResults.collect()
assert(byPartitionResults(0).length == vectorLength)
collectedAllReduceResults.put(byPartitionResults(0))
}
}
sparkThread.start()
assert(tracker.waitFor(0L) == 0)
sparkThread.join()
assert(collectedAllReduceResults.poll().sameElements(maxVec))
}
test("build RDD containing boosters with the specified worker number") { test("build RDD containing boosters with the specified worker number") {
val trainingRDD = buildTrainingRDD(sc) val trainingRDD = buildTrainingRDD(sc)
@ -36,7 +78,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
trainingRDD, trainingRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1", List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap, "objective" -> "binary:logistic").toMap,
new scala.collection.mutable.HashMap[String, String], new java.util.HashMap[String, String](),
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true) numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true)
val boosterCount = boosterRDD.count() val boosterCount = boosterRDD.count()
assert(boosterCount === 2) assert(boosterCount === 2)
@ -59,6 +101,21 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
cleanExternalCache("XGBoostSuite") cleanExternalCache("XGBoostSuite")
} }
test("training with Scala-implemented Rabit tracker") {
val eval = new EvalError()
val trainingRDD = buildTrainingRDD(sc)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
}
test("test with dense vectors containing missing value") { test("test with dense vectors containing missing value") {
def buildDenseRDD(): RDD[LabeledPoint] = { def buildDenseRDD(): RDD[LabeledPoint] = {
val nrow = 100 val nrow = 100

View File

@ -108,5 +108,17 @@
<version>4.11</version> <version>4.11</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor_${scala.binary.version}</artifactId>
<version>2.3.11</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-testkit_${scala.binary.version}</artifactId>
<version>2.3.11</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
</project> </project>

View File

@ -0,0 +1,43 @@
package ml.dmlc.xgboost4j.java;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
* Interface for Rabit tracker implementations with three public methods:
*
* - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given
* timeout value (in milliseconds.)
* - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
* milliseconds.
*
* Each implementation is expected to implement a callback function
*
* public void uncaughtException(Threat t, Throwable e) { ... }
*
* to interrupt waitFor() in order to prevent the tracker from hanging indefinitely.
*
* The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
* brokers connections between workers.
*/
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
public enum TrackerStatus {
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
private int statusCode;
TrackerStatus(int statusCode) {
this.statusCode = statusCode;
}
public int getStatusCode() {
return this.statusCode;
}
}
Map<String, String> getWorkerEnvs();
boolean start(long workerConnectionTimeout);
// taskExecutionTimeout has no effect in current version of XGBoost.
int waitFor(long taskExecutionTimeout);
}

View File

@ -2,6 +2,9 @@ package ml.dmlc.xgboost4j.java;
import java.io.IOException; import java.io.IOException;
import java.io.Serializable; import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Map; import java.util.Map;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
@ -22,6 +25,42 @@ public class Rabit {
} }
} }
public enum OpType implements Serializable {
MAX(0), MIN(1), SUM(2), BITWISE_OR(3);
private int op;
public int getOperand() {
return this.op;
}
OpType(int op) {
this.op = op;
}
}
public enum DataType implements Serializable {
CHAR(0, 1), UCHAR(1, 1), INT(2, 4), UNIT(3, 4),
LONG(4, 8), ULONG(5, 8), FLOAT(6, 4), DOUBLE(7, 8),
LONGLONG(8, 8), ULONGLONG(9, 8);
private int enumOp;
private int size;
public int getEnumOp() {
return this.enumOp;
}
public int getSize() {
return this.size;
}
DataType(int enumOp, int size) {
this.enumOp = enumOp;
this.size = size;
}
}
private static void checkCall(int ret) throws XGBoostError { private static void checkCall(int ret) throws XGBoostError {
if (ret != 0) { if (ret != 0) {
throw new XGBoostError(XGBoostJNI.XGBGetLastError()); throw new XGBoostError(XGBoostJNI.XGBGetLastError());
@ -92,4 +131,30 @@ public class Rabit {
checkCall(XGBoostJNI.RabitGetWorldSize(out)); checkCall(XGBoostJNI.RabitGetWorldSize(out));
return out[0]; return out[0];
} }
/**
* perform Allreduce on distributed float vectors using operator op.
* This implementation of allReduce does not support customized prepare function callback in the
* native code, as this function is meant for testing purposes only (to test the Rabit tracker.)
*
* @param elements local elements on distributed workers.
* @param op operator used for Allreduce.
* @return All-reduced float elements according to the given operator.
*/
public static float[] allReduce(float[] elements, OpType op) {
DataType dataType = DataType.FLOAT;
ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
.order(ByteOrder.nativeOrder());
for (float el : elements) {
buffer.putFloat(el);
}
buffer.flip();
XGBoostJNI.RabitAllreduce(buffer, elements.length, dataType.getEnumOp(), op.getOperand());
float[] results = new float[elements.length];
buffer.asFloatBuffer().get(results);
return results;
}
} }

View File

@ -5,15 +5,24 @@ package ml.dmlc.xgboost4j.java;
import java.io.*; import java.io.*;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
/** /**
* Distributed RabitTracker, need to be started on driver code before running distributed jobs. * Java implementation of the Rabit tracker to coordinate distributed workers.
* As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
* start() and waitFor() methods (i.e., the timeout is infinite.)
*
* For systems lacking Python environment, or for timeout functionality, consider using the Scala
* Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
* provides timeout support.
*
* The tracker must be started on driver node before running distributed jobs.
*/ */
public class RabitTracker { public class RabitTracker implements IRabitTracker {
// Maybe per tracker logger? // Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class); private static final Log logger = LogFactory.getLog(RabitTracker.class);
// tracker python file. // tracker python file.
@ -69,7 +78,6 @@ public class RabitTracker {
} }
} }
public RabitTracker(int numWorkers) public RabitTracker(int numWorkers)
throws XGBoostError { throws XGBoostError {
if (numWorkers < 1) { if (numWorkers < 1) {
@ -78,6 +86,17 @@ public class RabitTracker {
this.numWorkers = numWorkers; this.numWorkers = numWorkers;
} }
public void uncaughtException(Thread t, Throwable e) {
logger.error("Uncaught exception thrown by worker:", e);
try {
Thread.sleep(5000L);
} catch (InterruptedException ex) {
logger.error(ex);
} finally {
trackerProcess.get().destroy();
}
}
/** /**
* Get environments that can be used to pass to worker. * Get environments that can be used to pass to worker.
* @return The environment settings. * @return The environment settings.
@ -126,7 +145,13 @@ public class RabitTracker {
} }
} }
public boolean start() { public boolean start(long timeout) {
if (timeout > 0L) {
logger.warn("Python RabitTracker does not support timeout. " +
"The tracker will wait for all workers to connect indefinitely, unless " +
"it is interrupted manually. Use the Scala RabitTracker for timeout support.");
}
if (startTrackerProcess()) { if (startTrackerProcess()) {
logger.debug("Tracker started, with env=" + envs.toString()); logger.debug("Tracker started, with env=" + envs.toString());
System.out.println("Tracker started, with env=" + envs.toString()); System.out.println("Tracker started, with env=" + envs.toString());
@ -142,7 +167,14 @@ public class RabitTracker {
} }
} }
public int waitFor() { public int waitFor(long timeout) {
if (timeout > 0L) {
logger.warn("Python RabitTracker does not support timeout. " +
"The tracker will wait for either all workers to finish tasks and send " +
"shutdown signal, or manual interruptions. " +
"Use the Scala RabitTracker for timeout support.");
}
try { try {
trackerProcess.get().waitFor(); trackerProcess.get().waitFor();
int returnVal = trackerProcess.get().exitValue(); int returnVal = trackerProcess.get().exitValue();
@ -153,7 +185,7 @@ public class RabitTracker {
// we should not get here as RabitTracker is accessed in the main thread // we should not get here as RabitTracker is accessed in the main thread
e.printStackTrace(); e.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly"); logger.error("the RabitTracker thread is terminated unexpectedly");
return 1; return TrackerStatus.INTERRUPTED.getStatusCode();
} }
} }
} }

View File

@ -16,6 +16,8 @@
package ml.dmlc.xgboost4j.java; package ml.dmlc.xgboost4j.java;
import java.nio.ByteBuffer;
/** /**
* xgboost JNI functions * xgboost JNI functions
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
@ -97,4 +99,9 @@ class XGBoostJNI {
public final static native int RabitGetRank(int[] out); public final static native int RabitGetRank(int[] out);
public final static native int RabitGetWorldSize(int[] out); public final static native int RabitGetWorldSize(int[] out);
public final static native int RabitVersionNumber(int[] out); public final static native int RabitVersionNumber(int[] out);
// Perform Allreduce operation on data in sendrecvbuf.
// This JNI function does not support the callback function for data preparation yet.
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
int enum_dtype, int enum_op);
} }

View File

@ -0,0 +1,190 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit
import java.net.{InetAddress, InetSocketAddress}
import akka.actor.ActorSystem
import akka.pattern.ask
import ml.dmlc.xgboost4j.java.IRabitTracker
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler
import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.util.{Failure, Success, Try}
/**
* Scala implementation of the Rabit tracker interface without Python dependency.
* The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker
* (and thus any distributed tasks) to hang indefinitely due to network issues or worker node
* failures.
*
* Note that this implementation is currently experimental, and should be used at your own risk.
*
* Example usage:
* {{{
* import scala.concurrent.duration._
*
* val tracker = new RabitTracker(32)
* // allow up to 10 minutes for all workers to connect to the tracker.
* tracker.start(10 minutes)
*
* /* ...
* launching workers in parallel
* ...
* */
*
* // wait for worker execution up to 6 hours.
* // providing a finite timeout prevents a long-running task from hanging forever in
* // catastrophic events, like the loss of an executor during model training.
* tracker.waitFor(6 hours)
* }}}
*
* @param numWorkers Number of distributed workers from which the tracker expects connections.
* @param port The minimum port number that the tracker binds to.
* If port is omitted, or given as None, a random ephemeral port is chosen at runtime.
* @param maxPortTrials The maximum number of trials of socket binding, by sequentially
* increasing the port number.
*/
private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
maxPortTrials: Int = 1000)
extends IRabitTracker {
import scala.collection.JavaConverters._
require(numWorkers >=1, "numWorkers must be greater than or equal to one (1).")
val system = ActorSystem.create("RabitTracker")
val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds)
private[this] val tcpBindingTimeout: Duration = 1 minute
var workerEnvs: Map[String, String] = Map.empty
override def uncaughtException(t: Thread, e: Throwable): Unit = {
handler ? RabitTrackerHandler.InterruptTracker(e)
}
/**
* Start the Rabit tracker.
*
* @param timeout The timeout for awaiting connections from worker nodes.
* Note that when used in Spark applications, because all Spark transformations are
* lazily executed, the I/O time for loading RDDs/DataFrames from external sources
* (local dist, HDFS, S3 etc.) must be taken into account for the timeout value.
* If the timeout value is too small, the Rabit tracker will likely timeout before workers
* establishing connections to the tracker, due to the overhead of loading data.
* Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
* running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
* @return Boolean flag indicating if the Rabit tracker starts successfully.
*/
private def start(timeout: Duration): Boolean = {
handler ? RabitTrackerHandler.StartTracker(
new InetSocketAddress(InetAddress.getLocalHost, port.getOrElse(0)), maxPortTrials, timeout)
// block by waiting for the actor to bind to a port
Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
.asInstanceOf[Future[Map[String, String]]]) match {
case Success(futurePortBound) =>
// The success of the Future is contingent on binding to an InetSocketAddress.
val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess
if (isBound) {
workerEnvs = Await.result(futurePortBound, 0 nano)
}
isBound
case Failure(ex: Throwable) =>
false
}
}
/**
* Start the Rabit tracker.
*
* @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
* connections. If a non-positive value is provided, the tracker
* waits for incoming worker connections indefinitely.
* @return Boolean flag indicating if the Rabit tracker starts successfully.
*/
def start(connectionTimeoutMillis: Long): Boolean = {
if (connectionTimeoutMillis <= 0) {
start(Duration.Inf)
} else {
start(Duration.fromNanos(connectionTimeoutMillis * 1e6))
}
}
/**
* Get a Map of necessary environment variables to initiate Rabit workers.
*
* @return HashMap containing tracker information.
*/
def getWorkerEnvs: java.util.Map[String, String] = {
new java.util.HashMap((workerEnvs ++ Map(
"DMLC_NUM_WORKER" -> numWorkers.toString,
"DMLC_NUM_SERVER" -> "0"
)).asJava)
}
/**
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
* This method blocks until timeout or task completion.
*
* @param atMost the maximum execution time for the workers. By default,
* the tracker waits for the workers indefinitely.
* @return 0 if the tasks complete successfully, and non-zero otherwise.
*/
private def waitFor(atMost: Duration): Int = {
// request the completion Future from the tracker actor
Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
.asInstanceOf[Future[Int]]) match {
case Success(futureCompleted) =>
// wait for all workers to complete synchronously.
val statusCode = Try(Await.result(futureCompleted, atMost)) match {
case Success(n) if n == numWorkers =>
IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
case Success(n) if n < numWorkers =>
IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
case Failure(e) =>
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
}
system.shutdown()
statusCode
case Failure(ex: Throwable) =>
if (!system.isTerminated) {
system.shutdown()
}
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
}
}
/**
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
* This method blocks until timeout or task completion.
*
* @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a
* non-positive number is given, the tracker waits indefinitely.
* @return 0 if the tasks complete successfully, and non-zero otherwise
*/
def waitFor(atMostMillis: Long): Int = {
if (atMostMillis <= 0) {
waitFor(Duration.Inf)
} else {
waitFor(Duration.fromNanos(atMostMillis * 1e6))
}
}
}

View File

@ -0,0 +1,362 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.handler
import java.net.InetSocketAddress
import java.util.UUID
import scala.concurrent.duration._
import scala.collection.mutable
import scala.concurrent.{Promise, TimeoutException}
import akka.io.{IO, Tcp}
import akka.actor._
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}
import scala.util.{Failure, Random, Success, Try}
/** The Akka actor for handling and coordinating Rabit worker connections.
* This is the main actor for handling socket connections, interacting with the synchronous
* tracker interface, and resolving tree/ring/parent dependencies between workers.
*
* @param numWorkers Number of workers to track.
*/
private[scala] class RabitTrackerHandler(numWorkers: Int)
extends Actor with ActorLogging {
import context.system
import RabitWorkerHandler._
import RabitTrackerHandler._
private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
private[this] val promisedShutdownWorkers = Promise[Int]()
private[this] val tcpManager = IO(Tcp)
// resolves worker connection dependency.
val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")
// workers that have sent "shutdown" signal
private[this] val shutdownWorkers = mutable.Set.empty[Int]
private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
private[this] var maxPortTrials = 0
private[this] var workerConnectionTimeout: Duration = Duration.Inf
private[this] var portTrials = 0
private[this] val startedWorkers = mutable.Set.empty[Int]
val linkMap = new LinkMap(numWorkers)
def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
rank match {
case r if r >= 0 => Some(r)
case _ =>
jobId match {
case "NULL" => None
case jid => jobToRankMap.get(jid)
}
}
}
/**
* Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
* by the RabitWorkerHandler.
*
* @param event Generic Tcp.Event
*/
private def handleTcpEvents(event: Tcp.Event): Unit = event match {
case Tcp.Bound(local) =>
// expect all workers to connect within timeout
log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
log.info(s"Worker connection timeout is $workerConnectionTimeout.")
context.setReceiveTimeout(workerConnectionTimeout)
promisedWorkerEnvs.success(Map(
"DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
"DMLC_TRACKER_PORT" -> local.getPort.toString,
// not required because the world size will be communicated to the
// worker node after the rank is assigned.
"rabit_world_size" -> numWorkers.toString
))
case Tcp.CommandFailed(cmd: Tcp.Bind) =>
if (portTrials < maxPortTrials) {
portTrials += 1
tcpManager ! Tcp.Bind(self,
new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1),
backlog = 256)
}
case Tcp.Connected(remote, local) =>
log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
// revoke timeout if all workers have connected.
val workerHandler = context.actorOf(RabitWorkerHandler.props(
remote.getAddress.getHostAddress, numWorkers, self, sender()
), s"ConnectionHandler-${UUID.randomUUID().toString}")
val connection = sender()
connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)
actorRefToHost.put(workerHandler, remote.getAddress.getHostName)
}
/**
* Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
* to interact with the tracker interface.
*
* @param trackerMsg control messages sent by RabitTracker class.
*/
private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
trackerMsg match {
case msg: StartTracker =>
maxPortTrials = msg.maxPortTrials
workerConnectionTimeout = msg.connectionTimeout
// if the port number is missing, try binding to a random ephemeral port.
if (msg.addr.getPort == 0) {
tcpManager ! Tcp.Bind(self,
new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
backlog = 256)
} else {
tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
}
sender() ! true
case RequestBoundFuture =>
sender() ! promisedWorkerEnvs.future
case RequestCompletionFuture =>
sender() ! promisedShutdownWorkers.future
case InterruptTracker(e) =>
log.error(e, "Uncaught exception thrown by worker.")
// make sure that waitFor() does not hang indefinitely.
promisedShutdownWorkers.failure(e)
context.stop(self)
}
/**
* Handles messages sent by child actors representing connecting Rabit workers, by brokering
* messages to the dependency resolver, and processing worker commands.
*
* @param workerMsg Message sent by RabitWorkerHandler actors.
*/
private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
case req @ RequestAwaitConnWorkers(_, _) =>
// since the requester may request to connect to other workers
// that have not fully set up, delegate this request to the
// dependency resolver which handles the dependencies properly.
resolver forward req
// ---- Rabit worker commands: start/recover/shutdown/print ----
case WorkerTrackerPrint(_, _, _, msg) =>
log.info(msg.trim)
case WorkerShutdown(rank, _, _) =>
assert(rank >= 0, "Invalid rank.")
assert(!shutdownWorkers.contains(rank))
shutdownWorkers.add(rank)
log.info(s"Received shutdown signal from $rank")
if (shutdownWorkers.size == numWorkers) {
promisedShutdownWorkers.success(shutdownWorkers.size)
context.stop(self)
}
case WorkerRecover(prevRank, worldSize, jobId) =>
assert(prevRank >= 0)
sender() ! linkMap.assignRank(prevRank)
case WorkerStart(rank, worldSize, jobId) =>
assert(worldSize == numWorkers || worldSize == -1,
s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
)
Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
case Success(wkRank) =>
if (jobId != "NULL") {
jobToRankMap.put(jobId, wkRank)
}
val assignedRank = linkMap.assignRank(wkRank)
sender() ! assignedRank
resolver ! assignedRank
log.info("Received start signal from " +
s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]")
case Failure(ex: IndexOutOfBoundsException) =>
// More than worldSize workers have connected, likely due to executor loss.
// Since Rabit currently does not support crash recovery (because the Allreduce results
// are not cached by the tracker, and because existing workers cannot reestablish
// connections to newly spawned executor/worker), the most reasonble action here is to
// interrupt the tracker immediate with failure state.
log.error("Received invalid start signal from " +
s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started."
)
promisedShutdownWorkers.failure(new XGBoostError("Invalid start signal" +
" received from worker, likely due to executor loss."))
case Failure(ex) =>
log.error(ex, "Unexpected error")
promisedShutdownWorkers.failure(ex)
}
// ---- Dependency resolving related messages ----
case msg @ WorkerStarted(host, rank, awaitingAcceptance) =>
log.info(s"Worker $host (rank: $rank) has started.")
resolver forward msg
startedWorkers.add(rank)
if (startedWorkers.size == numWorkers) {
log.info("All workers have started.")
}
case req @ DropFromWaitingList(_) =>
// all peer workers in dependency link map have connected;
// forward message to resolver to update dependencies.
resolver forward req
case _ =>
}
def receive: Actor.Receive = {
case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)
case akka.actor.ReceiveTimeout =>
if (startedWorkers.size < numWorkers) {
promisedShutdownWorkers.failure(
new TimeoutException("Timed out waiting for workers to connect: " +
s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
)
context.stop(self)
}
context.setReceiveTimeout(Duration.Undefined)
}
}
/**
* Resolve the dependency between nodes as they connect to the tracker.
* The dependency is enforced that a worker of rank K depends on its neighbors (from the treeMap
* and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order of
* connections by workers, this dependency constraint assumes that a worker node connects first
* is likely to finish setup first.
*/
private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
import RabitWorkerHandler._
context.watch(handler)
case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])
// worker nodes that have connected, but have not send WorkerStarted message.
private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
private val startedWorkers = mutable.Set.empty[Int]
// worker nodes that have started, and await for connections.
private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]
def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
val connSet = awaitConnWorkers.toMap
.filterKeys(k => linkSet.contains(k))
AwaitingConnections(connSet, linkSet.size - connSet.size)
}
def receive: Actor.Receive = {
// a copy of the AssignedRank message that is also sent to the worker
case AssignedRank(rank, tree_neighbors, ring, parent) =>
// the workers that the worker of given `rank` depends on:
// worker of rank K only depends on workers with rank smaller than K.
val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
.filter{ r => r != -1 && r < rank}
log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
dependencyMap.put(rank, dependentWorkers)
case RequestAwaitConnWorkers(rank, toConnectSet) =>
val promise = Promise[AwaitingConnections]()
assert(dependencyMap.contains(rank))
val updatedDependency = dependencyMap(rank) diff startedWorkers
if (updatedDependency.isEmpty) {
// all dependencies are satisfied
log.debug(s"Rank $rank has all dependencies satisfied.")
promise.success(awaitingWorkers(toConnectSet))
} else {
log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
// promise is pending fulfillment due to unresolved dependency
pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
}
sender() ! promise.future
case WorkerStarted(_, started, awaitingAcceptance) =>
startedWorkers.add(started)
if (awaitingAcceptance > 0) {
awaitConnWorkers.put(started, sender())
}
// remove the started rank from all dependencies.
dependencyMap.remove(started)
dependencyMap.foreach { case (r, dset) =>
val updatedDependency = dset diff startedWorkers
// fulfill the future if all dependencies are met (started.)
if (updatedDependency.isEmpty) {
log.debug(s"Rank $r has all dependencies satisfied.")
pendingFulfillment.remove(r).map{
case Fulfillment(toConnectSet, promise) =>
promise.success(awaitingWorkers(toConnectSet))
}
}
dependencyMap.update(r, updatedDependency)
}
case DropFromWaitingList(rank) =>
assert(awaitConnWorkers.remove(rank).isDefined)
case Terminated(ref) =>
if (ref.equals(handler)) {
context.stop(self)
}
}
}
private[scala] object RabitTrackerHandler {
// Messages sent by RabitTracker to this RabitTrackerHandler actor
trait TrackerControlMessage
case object RequestCompletionFuture extends TrackerControlMessage
case object RequestBoundFuture extends TrackerControlMessage
// Start the Rabit tracker at given socket address awaiting worker connections.
// All workers must connect to the tracker before connectionTimeout, otherwise the tracker will
// shut down due to timeout.
case class StartTracker(addr: InetSocketAddress,
maxPortTrials: Int,
connectionTimeout: Duration) extends TrackerControlMessage
// To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
// driver for the distributed training.
case class InterruptTracker(e: Throwable) extends TrackerControlMessage
def props(numWorkers: Int): Props =
Props(new RabitTrackerHandler(numWorkers))
}

View File

@ -0,0 +1,456 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.handler
import java.nio.{ByteBuffer, ByteOrder}
import akka.io.Tcp
import akka.actor._
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Try
/**
* Actor to handle socket communication from worker node.
* To handle fragmentation in received data, this class acts like a FSM
* (finite-state machine) to keep track of the internal states.
*
* @param host IP address of the remote worker
* @param worldSize number of total workers
* @param tracker the RabitTrackerHandler actor reference
*/
private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
connection: ActorRef)
extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
with ActorLogging with Stash {
import RabitWorkerHandler._
import RabitTrackerHelpers._
private[this] var rank: Int = 0
private[this] var port: Int = 0
// indicate if the connection is transient (like "print" or "shutdown")
private[this] var transient: Boolean = false
private[this] var peerClosed: Boolean = false
// number of workers pending acceptance of current worker
private[this] var awaitingAcceptance: Int = 0
private[this] var neighboringWorkers = Set.empty[Int]
// TODO: use a single memory allocation to host all buffers,
// including the transient ones for writing.
private[this] val readBuffer = ByteBuffer.allocate(4096)
.order(ByteOrder.nativeOrder())
// in case the received message is longer than needed,
// stash the spilled over part in this buffer, and send
// to self when transition occurs.
private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
.order(ByteOrder.nativeOrder())
// when setup is complete, need to notify peer handlers
// to reduce the awaiting-connection counter.
private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None
private def resetBuffers(): Unit = {
readBuffer.clear()
if (spillOverBuffer.position() > 0) {
spillOverBuffer.flip()
self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
spillOverBuffer.clear()
}
}
private def stashSpillOver(buf: ByteBuffer): Unit = {
if (buf.remaining() > 0) spillOverBuffer.put(buf)
}
def getNeighboringWorkers: Set[Int] = neighboringWorkers
def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
val rank = buffer.getInt()
val worldSize = buffer.getInt()
val jobId = buffer.getString
val command = buffer.getString
command match {
case "start" => WorkerStart(rank, worldSize, jobId)
case "shutdown" =>
transient = true
WorkerShutdown(rank, worldSize, jobId)
case "recover" =>
require(rank >= 0, "Invalid rank for recovering worker.")
WorkerRecover(rank, worldSize, jobId)
case "print" =>
transient = true
WorkerTrackerPrint(rank, worldSize, jobId, buffer.getString)
}
}
startWith(AwaitingHandshake, DataStruct())
when(AwaitingHandshake) {
case Event(Tcp.Received(magic), _) =>
assert(magic.length == 4)
val purportedMagic = magic.asNativeOrderByteBuffer.getInt
assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")
// echo back the magic number
connection ! Tcp.Write(magic)
goto(AwaitingCommand) using StructTrackerCommand
}
when(AwaitingCommand) {
case Event(Tcp.Received(bytes), validator) =>
bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
if (validator.verify(readBuffer)) {
readBuffer.flip()
tracker ! decodeCommand(readBuffer)
stashSpillOver(readBuffer)
}
stay
// when rank for a worker is assigned, send encoded rank information
// back to worker over Tcp socket.
case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")
rank = assignedRank
// ranks from the ring
val ringRanks = List(
// ringPrev
if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
// ringNext
if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
)
// update the set of all linked workers to current worker.
neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet
connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
// to prevent reading before state transition
connection ! Tcp.SuspendReading
goto(BuildingLinkMap) using StructNodes
}
when(BuildingLinkMap) {
case Event(Tcp.Received(bytes), validator) =>
bytes.asByteBuffers.foreach { buf =>
readBuffer.put(buf)
}
if (validator.verify(readBuffer)) {
readBuffer.flip()
// for a freshly started worker, numConnected should be 0.
val numConnected = readBuffer.getInt()
val toConnectSet = neighboringWorkers.diff(
(0 until numConnected).map { index => readBuffer.getInt() }.toSet)
// check which workers are currently awaiting connections
tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
}
stay
// got a Future from the tracker (resolver) about workers that are
// currently awaiting connections (particularly from this node.)
case Event(future: Future[_], _) =>
// blocks execution until all dependencies for current worker is resolved.
Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
// numNotReachable is the number of workers that currently
// cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
// connections from those currently non-reachable nodes in the future.
case AwaitingConnections(waitConnNodes, numNotReachable) =>
log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")
val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
buf.putInt(waitConnNodes.size).putInt(numNotReachable)
buf.flip()
// cache this message until the final state (SetupComplete)
pendingAcknowledgement = Some(AcknowledgeAcceptance(
waitConnNodes, numNotReachable))
connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
if (waitConnNodes.isEmpty) {
connection ! Tcp.SuspendReading
goto(AwaitingErrorCount)
}
else {
waitConnNodes.foreach { case (peerRank, peerRef) =>
peerRef ! RequestWorkerHostPort
}
// a countdown for DivulgedHostPort messages.
stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
}
}
case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
val hostBytes = peerHost.getBytes()
val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
.order(ByteOrder.nativeOrder())
buffer.putInt(peerHost.length).put(hostBytes)
.putInt(peerPort).putInt(peerRank)
buffer.flip()
connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))
if (data.counter == 0) {
// to prevent reading before state transition
connection ! Tcp.SuspendReading
goto(AwaitingErrorCount)
}
else {
stay using data.decrement()
}
}
when(AwaitingErrorCount) {
case Event(Tcp.Received(numErrors), _) =>
val buf = numErrors.asNativeOrderByteBuffer
buf.getInt match {
case 0 =>
stashSpillOver(buf)
goto(AwaitingPortNumber)
case _ =>
stashSpillOver(buf)
goto(BuildingLinkMap) using StructNodes
}
}
when(AwaitingPortNumber) {
case Event(Tcp.Received(assignedPort), _) =>
assert(assignedPort.length == 4)
port = assignedPort.asNativeOrderByteBuffer.getInt
log.debug(s"Rank $rank listening @ $host:$port")
// wait until the worker closes connection.
if (peerClosed) goto(SetupComplete) else stay
case Event(Tcp.PeerClosed, _) =>
peerClosed = true
if (port == 0) stay else goto(SetupComplete)
}
when(SetupComplete) {
case Event(ReduceWaitCount(count: Int), _) =>
awaitingAcceptance -= count
// check peerClosed to avoid prematurely stopping this actor (which sends RST to worker)
if (awaitingAcceptance == 0 && peerClosed) {
tracker ! DropFromWaitingList(rank)
// no longer needed.
context.stop(self)
}
stay
case Event(AcknowledgeAcceptance(peers, numBad), _) =>
awaitingAcceptance = numBad
tracker ! WorkerStarted(host, rank, awaitingAcceptance)
peers.values.foreach { peer =>
peer ! ReduceWaitCount(1)
}
if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill
stay
// can only divulge the complete host and port information
// when this worker is declared fully connected (otherwise
// port information is still missing.)
case Event(RequestWorkerHostPort, _) =>
sender() ! DivulgedWorkerHostPort(rank, host, port)
stay
}
onTransition {
// reset buffer when state transitions as data becomes stale
case _ -> SetupComplete =>
connection ! Tcp.ResumeReading
resetBuffers()
if (pendingAcknowledgement.isDefined) {
self ! pendingAcknowledgement.get
}
case _ =>
connection ! Tcp.ResumeReading
resetBuffers()
}
// default message handler
whenUnhandled {
case Event(Tcp.PeerClosed, _) =>
peerClosed = true
if (transient) context.stop(self)
stay
}
}
private[scala] object RabitWorkerHandler {
val MAGIC_NUMBER = 0xff99
// Finite states of this actor, which acts like a FSM.
// The following states are defined in order as the FSM progresses.
sealed trait State
// [1] Initial state, awaiting worker to send magic number per protocol.
case object AwaitingHandshake extends State
// [2] Awaiting worker to send command (start/print/recover/shutdown etc.)
case object AwaitingCommand extends State
// [3] Brokers connections between workers per ring/tree/parent link map.
case object BuildingLinkMap extends State
// [4] A transient state in which the worker reports the number of errors in establishing
// connections to other peer workers. If no errors, transition to next state.
case object AwaitingErrorCount extends State
// [5] Awaiting the worker to report its port number for accepting connections from peer workers.
// This port number information is later forwarded to linked workers.
case object AwaitingPortNumber extends State
// [6] Final state after completing the setup with the connecting worker. At this stage, the
// worker will have closed the Tcp connection. The actor remains alive to handle messages from
// peer actors representing workers with pending setups.
case object SetupComplete extends State
sealed trait DataField
case object IntField extends DataField
// an integer preceding the actual string
case object StringField extends DataField
case object IntSeqField extends DataField
object DataStruct {
def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
}
// Internal data pertaining to individual state, used to verify the validity of packets sent by
// workers.
case class DataStruct(fields: Seq[DataField], counter: Int) {
/**
* Validate whether the provided buffer is complete (i.e., contains
* all data fields specified for this DataStruct.)
*
* @param buf a byte buffer containing received data.
*/
def verify(buf: ByteBuffer): Boolean = {
if (fields.isEmpty) return true
val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
dupBuf.flip()
Try(fields.foldLeft(true) {
case (complete, field) =>
val remBytes = dupBuf.remaining()
complete && (remBytes > 0) && (remBytes >= (field match {
case IntField =>
dupBuf.position(dupBuf.position() + 4)
4
case StringField =>
val strLen = dupBuf.getInt
dupBuf.position(dupBuf.position() + strLen)
4 + strLen
case IntSeqField =>
val seqLen = dupBuf.getInt
dupBuf.position(dupBuf.position() + seqLen * 4)
4 + seqLen * 4
}))
}).getOrElse(false)
}
def increment(): DataStruct = DataStruct(fields, counter + 1)
def decrement(): DataStruct = DataStruct(fields, counter - 1)
}
val StructNodes = DataStruct(List(IntSeqField), 0)
val StructTrackerCommand = DataStruct(List(
IntField, IntField, StringField, StringField
), 0)
// ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ----
// RabitWorkerHandler --> RabitTrackerHandler
sealed trait RabitWorkerRequest
// RabitWorkerHandler <-- RabitTrackerHandler
sealed trait RabitWorkerResponse
// Representations of decoded worker commands.
abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
def rank: Int
def worldSize: Int
def jobId: String
def encode: ByteString = {
val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length)
.order(ByteOrder.nativeOrder())
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
.putInt(command.length).put(command.getBytes()).flip()
ByteString.fromByteBuffer(buf)
}
}
case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("start")
case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("shutdown")
case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("recover")
case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
extends TrackerCommand("print") {
override def encode: ByteString = {
val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length)
.order(ByteOrder.nativeOrder())
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
.putInt(command.length).put(command.getBytes())
.putInt(msg.length).put(msg.getBytes()).flip()
ByteString.fromByteBuffer(buf)
}
}
// Request to remove the worker of given rank from the list of workers awaiting peer connections.
case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
// Notify the tracker that the worker of given rank has finished setup and started.
case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
extends RabitWorkerRequest
// Request the set of workers to connect to, according to the LinkMap structure.
case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
extends RabitWorkerRequest
// Request, from the tracker, the set of nodes to connect.
case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
extends RabitWorkerResponse
// ---- Messages between ConnectionHandler actors ----
sealed trait IntraWorkerMessage
// Notify neighboring workers to decrease the counter of awaiting workers by `count`.
case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
// Request host and port information from peer ConnectionHandler actors (acting on behave of
// connecting workers.) This message will be brokered by RabitTrackerHandler.
case object RequestWorkerHostPort extends IntraWorkerMessage
// Response to the above request
case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
// A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
extends IntraWorkerMessage
// ---- End of message definitions ----
def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
}
}

View File

@ -0,0 +1,136 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.util
import java.nio.{ByteBuffer, ByteOrder}
/**
* The assigned rank to a connecting Rabit worker, along with the information of the ranks of
* its linked peer workers, which are critical to perform Allreduce.
* When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
* client, RabitTrackerHandler utilizes LinkMap to figure out linkage relationships, and respond
* with this class as a message, which is later encoded to byte string, and sent over socket
* connection to the worker client.
*
* @param rank assigned rank (ranked by worker connection order: first worker connecting to the
* tracker is assigned rank 0, second with rank 1, etc.)
* @param neighbors ranks of neighboring workers in a tree map.
* @param ring ranks of neighboring workers in a ring map.
* @param parent rank of the parent worker.
*/
private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
ring: (Int, Int), parent: Int) {
/**
* Encode the AssignedRank message into byte sequence for socket communication with Rabit worker
* client.
* @param worldSize the number of total distributed workers. Must match `numWorkers` used in
* LinkMap.
* @return a ByteBuffer containing encoded data.
*/
def toByteBuffer(worldSize: Int): ByteBuffer = {
val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder())
buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length)
// neighbors in tree structure
neighbors.foreach { n => buffer.putInt(n) }
buffer.putInt(if (ring._1 != -1 && ring._1 != rank) ring._1 else -1)
buffer.putInt(if (ring._2 != -1 && ring._2 != rank) ring._2 else -1)
buffer.flip()
buffer
}
}
private[rabit] class LinkMap(numWorkers: Int) {
private def getNeighbors(rank: Int): Seq[Int] = {
val rank1 = rank + 1
Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
r >= 0 && r < numWorkers
}
}
/**
* Construct a ring structure that tends to share nodes with the tree.
*
* @param treeMap
* @param parentMap
* @param rank
* @return Seq[Int] instance starting from rank.
*/
private def constructShareRing(treeMap: Map[Int, Seq[Int]],
parentMap: Map[Int, Int],
rank: Int = 0): Seq[Int] = {
treeMap(rank).toSet - parentMap(rank) match {
case emptySet if emptySet.isEmpty =>
List(rank)
case connectionSet =>
connectionSet.zipWithIndex.foldLeft(List(rank)) {
case (ringSeq, (v, cnt)) =>
val vConnSeq = constructShareRing(treeMap, parentMap, v)
vConnSeq match {
case vconn if vconn.size == cnt + 1 =>
ringSeq ++ vconn.reverse
case vconn =>
ringSeq ++ vconn
}
}
}
}
/**
* Construct a ring connection used to recover local data.
*
* @param treeMap
* @param parentMap
*/
private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = {
assert(parentMap(0) == -1)
val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
assert(sharedRing.length == treeMap.size)
(0 until numWorkers).map { r =>
val rPrev = (r + numWorkers - 1) % numWorkers
val rNext = (r + 1) % numWorkers
sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
}.toMap
}
private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
private[this] val parentMap_ = (0 until numWorkers).map{ r => r -> ((r + 1) / 2 - 1) }.toMap
private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)
val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
case ((rmap, k), i) =>
val kNext = ringMap_(k)._2
(rmap ++ Map(kNext -> (i + 1)), kNext)
}._1
val ringMap = ringMap_.map {
case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
}
val treeMap = treeMap_.map {
case (k, vSeq) => rMap_(k) -> vSeq.map{ v => rMap_(v) }
}
val parentMap = parentMap_.map {
case (k, v) if k == 0 =>
rMap_(k) -> -1
case (k, v) =>
rMap_(k) -> rMap_(v)
}
def assignRank(rank: Int): AssignedRank = {
AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
}
}

View File

@ -0,0 +1,39 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.util
import java.nio.{ByteOrder, ByteBuffer}
import akka.util.ByteString
private[rabit] object RabitTrackerHelpers {
implicit class ByteStringHelplers(bs: ByteString) {
// Java by default uses big endian. Enforce native endian so that
// the byte order is consistent with the workers.
def asNativeOrderByteBuffer: ByteBuffer = {
bs.asByteBuffer.order(ByteOrder.nativeOrder())
}
}
implicit class ByteBufferHelpers(buf: ByteBuffer) {
def getString: String = {
val len = buf.getInt()
val stringBuffer = ByteBuffer.allocate(len).order(ByteOrder.nativeOrder())
buf.get(stringBuffer.array(), 0, len)
new String(stringBuffer.array(), "utf-8")
}
}
}

View File

@ -94,6 +94,7 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
long max_elem = cbatch.offset[cbatch.size]; long max_elem = cbatch.offset[cbatch.size];
cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0); cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0);
cbatch.value = jenv->GetFloatArrayElements(jvalue, 0); cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);
CHECK_EQ(jenv->GetArrayLength(jindex), max_elem) CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
<< "batch.index.length must equal batch.offset.back()"; << "batch.index.length must equal batch.offset.back()";
CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem) CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem)
@ -756,3 +757,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
jenv->SetIntArrayRegion(jout, 0, 1, &out); jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0; return 0;
} }
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitAllreduce
* Signature: (Ljava/nio/ByteBuffer;III)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
(JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
RabitAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op, NULL, NULL);
return 0;
}

View File

@ -303,6 +303,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
(JNIEnv *, jclass, jintArray); (JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitAllreduce
* Signature: (Ljava/nio/ByteBuffer;III)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
(JNIEnv *, jclass, jobject, jint, jint, jint);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -0,0 +1,224 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit
import java.nio.{ByteBuffer, ByteOrder}
import akka.actor.{ActorRef, ActorSystem}
import akka.io.Tcp
import akka.testkit.{ImplicitSender, TestFSMRef, TestKit, TestProbe}
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler._
import ml.dmlc.xgboost4j.scala.rabit.util.LinkMap
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpecLike, Matchers}
import scala.concurrent.Promise
object RabitTrackerConnectionHandlerTest {
def intSeqToByteString(seq: Seq[Int]): ByteString = {
val buf = ByteBuffer.allocate(seq.length * 4).order(ByteOrder.nativeOrder())
seq.foreach { i => buf.putInt(i) }
buf.flip()
ByteString.fromByteBuffer(buf)
}
}
@RunWith(classOf[JUnitRunner])
class RabitTrackerConnectionHandlerTest
extends TestKit(ActorSystem("RabitTrackerConnectionHandlerTest"))
with FlatSpecLike with Matchers with ImplicitSender {
import RabitTrackerConnectionHandlerTest._
val magic = intSeqToByteString(List(0xff99))
"RabitTrackerConnectionHandler" should "handle Rabit client 'start' command properly" in {
val trackerProbe = TestProbe()
val connProbe = TestProbe()
val worldSize = 4
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
trackerProbe.ref, connProbe.ref))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
// send mock magic number
fsm ! Tcp.Received(magic)
connProbe.expectMsg(Tcp.Write(magic))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
// ResumeReading should be seen once state transitions
connProbe.expectMsg(Tcp.ResumeReading)
// send mock tracker command in fragments: the handler should be able to handle it.
val bufRank = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
bufRank.putInt(0).putInt(worldSize).flip()
val bufJobId = ByteBuffer.allocate(5).order(ByteOrder.nativeOrder())
bufJobId.putInt(1).put(Array[Byte]('0')).flip()
val bufCmd = ByteBuffer.allocate(9).order(ByteOrder.nativeOrder())
bufCmd.putInt(5).put("start".getBytes()).flip()
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufRank))
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufJobId))
// the state should not change for incomplete command data.
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
// send the last fragment, and expect message at tracker actor.
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
val linkMap = new LinkMap(worldSize)
val assignedRank = linkMap.assignRank(0)
trackerProbe.reply(assignedRank)
connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
assignedRank.toByteBuffer(worldSize)
)))
// reading should be suspended upon transitioning to BuildingLinkMap
connProbe.expectMsg(Tcp.SuspendReading)
// state should transition with according state data changes.
fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
connProbe.expectMsg(Tcp.ResumeReading)
// since the connection handler in test has rank 0, it will not have any nodes to connect to.
fsm ! Tcp.Received(intSeqToByteString(List(0)))
trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
// return mock response to the connection handler
val awaitConnPromise = Promise[AwaitingConnections]()
awaitConnPromise.success(AwaitingConnections(Map.empty[Int, ActorRef],
fsm.underlyingActor.getNeighboringWorkers.size
))
fsm ! awaitConnPromise.future
connProbe.expectMsg(Tcp.Write(
intSeqToByteString(List(0, fsm.underlyingActor.getNeighboringWorkers.size))
))
connProbe.expectMsg(Tcp.SuspendReading)
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingErrorCount
connProbe.expectMsg(Tcp.ResumeReading)
// send mock error count (0)
fsm ! Tcp.Received(intSeqToByteString(List(0)))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
connProbe.expectMsg(Tcp.ResumeReading)
// simulate Tcp.PeerClosed event first, then Tcp.Received to test handling of async events.
fsm ! Tcp.PeerClosed
// state should not transition
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
fsm ! Tcp.Received(intSeqToByteString(List(32768)))
fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
connProbe.expectMsg(Tcp.ResumeReading)
trackerProbe.expectMsg(RabitWorkerHandler.WorkerStarted("localhost", 0, 2))
val handlerStopProbe = TestProbe()
handlerStopProbe watch fsm
// simulate connections from other workers by mocking ReduceWaitCount commands
fsm ! RabitWorkerHandler.ReduceWaitCount(1)
fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
fsm ! RabitWorkerHandler.ReduceWaitCount(1)
trackerProbe.expectMsg(RabitWorkerHandler.DropFromWaitingList(0))
handlerStopProbe.expectTerminated(fsm)
// all done.
}
it should "forward print command to tracker" in {
val trackerProbe = TestProbe()
val connProbe = TestProbe()
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
trackerProbe.ref, connProbe.ref))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
fsm ! Tcp.Received(magic)
connProbe.expectMsg(Tcp.Write(magic))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
// ResumeReading should be seen once state transitions
connProbe.expectMsg(Tcp.ResumeReading)
val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!")
fsm ! Tcp.Received(printCmd.encode)
trackerProbe.expectMsg(printCmd)
}
it should "handle spill-over Tcp data correctly between state transition" in {
val trackerProbe = TestProbe()
val connProbe = TestProbe()
val worldSize = 4
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
trackerProbe.ref, connProbe.ref))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
// send mock magic number
fsm ! Tcp.Received(magic)
connProbe.expectMsg(Tcp.Write(magic))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
// ResumeReading should be seen once state transitions
connProbe.expectMsg(Tcp.ResumeReading)
// send mock tracker command in fragments: the handler should be able to handle it.
val bufCmd = ByteBuffer.allocate(26).order(ByteOrder.nativeOrder())
bufCmd.putInt(0).putInt(worldSize).putInt(1).put(Array[Byte]('0'))
.putInt(5).put("start".getBytes())
// spilled-over data
.putInt(0).flip()
// send data with 4 extra bytes corresponding to the next state.
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
val linkMap = new LinkMap(worldSize)
val assignedRank = linkMap.assignRank(0)
trackerProbe.reply(assignedRank)
connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
assignedRank.toByteBuffer(worldSize)
)))
// reading should be suspended upon transitioning to BuildingLinkMap
connProbe.expectMsg(Tcp.SuspendReading)
// state should transition with according state data changes.
fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
connProbe.expectMsg(Tcp.ResumeReading)
// the handler should be able to handle spill-over data, and stash it until state transition.
trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
}
}