diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala index 366bf7b3d..58afd82e1 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala @@ -57,7 +57,7 @@ object CustomObjective { case e: XGBoostError => logger.error(e) null - case _ => + case _: Throwable => null } val grad = new Array[Float](nrow) diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala index 4c6adee99..9ac8c2668 100644 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala @@ -85,7 +85,7 @@ object XGBoost { def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int): XGBoostModel = { val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism) - if (tracker.start()) { + if (tracker.start(0L)) { dtrain .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs)) .reduce((x, y) => x).collect().head diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index f4a05cc1d..bf22f7fcc 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -16,11 +16,10 @@ package ml.dmlc.xgboost4j.scala.spark -import scala.collection.JavaConverters._ import scala.collection.mutable import 
scala.collection.mutable.ListBuffer - -import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError} +import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker} +import ml.dmlc.xgboost4j.scala.rabit.RabitTracker import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import org.apache.commons.logging.LogFactory import org.apache.hadoop.fs.{FSDataInputStream, Path} @@ -30,6 +29,25 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.{SparkContext, TaskContext} +import scala.concurrent.duration.{Duration, MILLISECONDS} + +object TrackerConf { + def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python") +} + +/** + * Rabit tracker configurations. + * @param workerConnectionTimeout The timeout for all workers to connect to the tracker. + * Set timeout length to zero to disable timeout. + * Use a finite, non-zero timeout value to prevent tracker from + * hanging indefinitely (supported by "scala" implementation only.) + * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of + * the Python Rabit tracker (in dmlc_core), whereas the latter is implemented + * in Scala without Python components, and with full support of timeouts. + * The Scala implementation is currently experimental, use at your own risk. 
+ */ +case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String) + object XGBoost extends Serializable { private val logger = LogFactory.getLog("XGBoostSpark") @@ -80,7 +98,7 @@ object XGBoost extends Serializable { private[spark] def buildDistributedBoosters( trainingSet: RDD[MLLabeledPoint], xgBoostConfMap: Map[String, Any], - rabitEnv: mutable.Map[String, String], + rabitEnv: java.util.Map[String, String], numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait, useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = { import DataUtils._ @@ -92,7 +110,7 @@ object XGBoost extends Serializable { partitionedTrainingSet.mapPartitions { trainingSamples => rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) - Rabit.init(rabitEnv.asJava) + Rabit.init(rabitEnv) var booster: Booster = null if (trainingSamples.hasNext) { val cacheFileName: String = { @@ -211,9 +229,21 @@ object XGBoost extends Serializable { overridedParams } - private def startTracker(nWorkers: Int): RabitTracker = { - val tracker = new RabitTracker(nWorkers) - require(tracker.start(), "FAULT: Failed to start tracker") + private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = { + val tracker: IRabitTracker = trackerConf.trackerImpl match { + case "scala" => new RabitTracker(nWorkers) + case "python" => new PyRabitTracker(nWorkers) + case _ => new PyRabitTracker(nWorkers) + } + + val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) { + trackerConf.workerConnectionTimeout.toMillis + } else { + // 0 == Duration.Inf + 0L + } + + require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker") tracker } @@ -227,7 +257,7 @@ object XGBoost extends Serializable { * @param obj the user-defined objective function, null by default * @param eval the user-defined evaluation function, null by default * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as 
- * true, the user may save the RAM cost for running XGBoost within Spark + * true, the user may save the RAM cost for running XGBoost within Spark * @param missing the value represented the missing value in the dataset * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed * @return XGBoostModel when successful training @@ -243,19 +273,26 @@ object XGBoost extends Serializable { " you have to specify the objective type as classification or regression with a" + " customized objective function") } - val tracker = startTracker(nWorkers) + val trackerConf = params.get("tracker_conf") match { + case None => TrackerConf() + case Some(conf: TrackerConf) => conf + case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " + + "instance of TrackerConf.") + } + val tracker = startTracker(nWorkers, trackerConf) val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext) val boosters = buildDistributedBoosters(trainingData, overridedConfMap, - tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing) + tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing) val sparkJobThread = new Thread() { override def run() { // force the job boosters.foreachPartition(() => _) } } + sparkJobThread.setUncaughtExceptionHandler(tracker) sparkJobThread.start() val isClsTask = isClassificationTask(params) - val trackerReturnVal = tracker.waitFor() + val trackerReturnVal = tracker.waitFor(0L) logger.info(s"Rabit returns with exit code $trackerReturnVal") postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread, isClsTask) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index 8d0f60cfd..212daadbc 100644 --- 
a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -16,9 +16,12 @@ package ml.dmlc.xgboost4j.scala.spark.params +import ml.dmlc.xgboost4j.scala.spark.TrackerConf import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import org.apache.spark.ml.param._ +import scala.concurrent.duration.{Duration, NANOSECONDS} + trait GeneralParams extends Params { /** @@ -69,7 +72,38 @@ trait GeneralParams extends Params { */ val missing = new FloatParam(this, "missing", "the value treated as missing") + /** + * Rabit tracker configurations. The parameter must be provided as an instance of the + * TrackerConf class, which has the following definition: + * + * case class TrackerConf(workerConnectionTimeout: Duration, + * trackerImpl: String) + * + * See below for detailed explanations. + * + * - trackerImpl: Select the implementation of Rabit tracker. + * default: "python" + * + * Choice between "python" or "scala". The former utilizes the Java wrapper of the + * Python Rabit tracker (in dmlc_core), and does not support timeout settings. + * The "scala" version removes Python components, and fully supports timeout settings. + * + * - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker. + * default: 0 milliseconds (no timeout) + * + * The timeout value should take the time of data loading and pre-processing into account, + * due to the lazy execution of Spark's operations. Alternatively, you may force Spark to + * perform data transformation before calling XGBoost.train(), so that this timeout truly + * reflects the connection delay. Set a reasonable timeout value to prevent model + * training/testing from hanging indefinitely, possibly due to network issues. + * Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf). 
+ * Ignored if the tracker implementation is "python". + */ + val trackerConf = new Param[TrackerConf](this, "tracker_conf", "Rabit tracker configurations") + setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1, useExternalMemory -> false, silent -> 0, - customObj -> null, customEval -> null, missing -> Float.NaN) + customObj -> null, customEval -> null, missing -> Float.NaN, + trackerConf -> TrackerConf() + ) } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitTrackerRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitTrackerRobustnessSuite.scala new file mode 100644 index 000000000..2d1dc2711 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitTrackerRobustnessSuite.scala @@ -0,0 +1,169 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ + +package ml.dmlc.xgboost4j.scala.spark + +import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker} +import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker} +import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus +import org.apache.spark.{SparkConf, SparkContext} +import org.scalatest.FunSuite + + +class RabitTrackerRobustnessSuite extends FunSuite with Utils { + test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") { + /* + Deliberately create new instances of SparkContext in each unit test to avoid reusing the + same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes + by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore + corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent + tests on a reentrant thread will crash the entire Spark application, an undesired side-effect + that should be avoided. + */ + val sparkConf = new SparkConf().setMaster("local[*]") + .setAppName("XGBoostSuite").set("spark.driver.memory", "512m") + implicit val sparkContext = new SparkContext(sparkConf) + sparkContext.setLogLevel("ERROR") + + val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache() + + val tracker = new PyRabitTracker(numWorkers) + tracker.start(0) + val trackerEnvs = tracker.getWorkerEnvs + + val workerCount: Int = numWorkers + /* + Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the + last created worker. A cascading event chain will be triggered once the RuntimeException is + thrown: the thread running the dummy spark job (sparkThread) catches the exception and + delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself. + + The Java RabitTracker class reacts to exceptions by killing the spawned process running + the Python tracker. 
If at least one Rabit worker has not yet connected to the tracker before + it is killed, the resulting connection failure will trigger the Rabit worker to call + "exit(-1);" in the native C++ code, effectively ending the dummy Spark task. + + In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are + isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks + running in separate containers. However, as unit tests are run in Spark local mode, in which + tasks are executed by threads belonging to the same process, one thread calling "exit(-1);" + ultimately kills the entire process, which also happens to host the Spark driver, causing + the entire Spark application to crash. + + To prevent unit tests from crashing, deterministic delays were introduced to make sure that + the exception is thrown last, ideally after all worker connections have been established. + For the same reason, the Java RabitTracker class delays the killing of the Python tracker + process to ensure that pending worker connections are handled. + */ + val dummyTasks = rdd.mapPartitions { iter => + Rabit.init(trackerEnvs) + val index = iter.next() + Thread.sleep(100 + index * 10) + if (index == workerCount) { + // kill the worker by throwing an exception + throw new RuntimeException("Worker exception.") + } + Rabit.shutdown() + Iterator(index) + }.cache() + + val sparkThread = new Thread() { + override def run(): Unit = { + // forces a Spark job. 
+ dummyTasks.foreachPartition(() => _) + } + } + + sparkThread.setUncaughtExceptionHandler(tracker) + sparkThread.start() + assert(tracker.waitFor(0) != 0) + sparkContext.stop() + } + + test("test Scala RabitTracker's exception handling: it should not hang forever.") { + val sparkConf = new SparkConf().setMaster("local[*]") + .setAppName("XGBoostSuite").set("spark.driver.memory", "512m") + implicit val sparkContext = new SparkContext(sparkConf) + sparkContext.setLogLevel("ERROR") + + val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache() + + val tracker = new ScalaRabitTracker(numWorkers) + tracker.start(0) + val trackerEnvs = tracker.getWorkerEnvs + + val workerCount: Int = numWorkers + val dummyTasks = rdd.mapPartitions { iter => + Rabit.init(trackerEnvs) + val index = iter.next() + Thread.sleep(100 + index * 10) + if (index == workerCount) { + // kill the worker by throwing an exception + throw new RuntimeException("Worker exception.") + } + Rabit.shutdown() + Iterator(index) + }.cache() + + val sparkThread = new Thread() { + override def run(): Unit = { + // forces a Spark job. + dummyTasks.foreachPartition(() => _) + } + } + sparkThread.setUncaughtExceptionHandler(tracker) + sparkThread.start() + assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode) + sparkContext.stop() + } + + test("test Scala RabitTracker's workerConnectionTimeout") { + val sparkConf = new SparkConf().setMaster("local[*]") + .setAppName("XGBoostSuite").set("spark.driver.memory", "512m") + implicit val sparkContext = new SparkContext(sparkConf) + sparkContext.setLogLevel("ERROR") + + val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache() + + val tracker = new ScalaRabitTracker(numWorkers) + tracker.start(500) + val trackerEnvs = tracker.getWorkerEnvs + + val dummyTasks = rdd.mapPartitions { iter => + val index = iter.next() + // simulate that the first worker cannot connect to tracker due to network issues. 
+ if (index != 1) { + Rabit.init(trackerEnvs) + Thread.sleep(1000) + Rabit.shutdown() + } + + Iterator(index) + }.cache() + + val sparkThread = new Thread() { + override def run(): Unit = { + // forces a Spark job. + dummyTasks.foreachPartition(() => _) + } + } + sparkThread.setUncaughtExceptionHandler(tracker) + sparkThread.start() + // should fail due to connection timeout + assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode) + sparkContext.stop() + } +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index 5faed7234..1874a3b6d 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -17,18 +17,60 @@ package ml.dmlc.xgboost4j.scala.spark import java.nio.file.Files +import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque} import scala.collection.mutable.ListBuffer import scala.util.Random - -import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} +import scala.concurrent.duration._ +import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.scala.DMatrix +import ml.dmlc.xgboost4j.scala.rabit.RabitTracker import org.apache.spark.SparkContext import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.linalg.{Vector => SparkVector, Vectors} +import org.apache.spark.ml.linalg.{Vectors, Vector => SparkVector} import org.apache.spark.rdd.RDD class XGBoostGeneralSuite extends SharedSparkContext with Utils { + test("test Rabit allreduce to validate Scala-implemented Rabit tracker") { + val vectorLength = 100 + val rdd = sc.parallelize( + (1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache() + + val tracker = new 
RabitTracker(numWorkers) + tracker.start(0) + val trackerEnvs = tracker.getWorkerEnvs + val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]() + + val rawData = rdd.mapPartitions { iter => + Iterator(iter.toArray) + }.collect() + + val maxVec = (0 until vectorLength).toArray.map { j => + (0 until numWorkers).toArray.map { i => rawData(i)(j) }.max + } + + val allReduceResults = rdd.mapPartitions { iter => + Rabit.init(trackerEnvs) + val arr = iter.toArray + val results = Rabit.allReduce(arr, Rabit.OpType.MAX) + Rabit.shutdown() + Iterator(results) + }.cache() + + val sparkThread = new Thread() { + override def run(): Unit = { + allReduceResults.foreachPartition(() => _) + val byPartitionResults = allReduceResults.collect() + assert(byPartitionResults(0).length == vectorLength) + collectedAllReduceResults.put(byPartitionResults(0)) + } + } + sparkThread.start() + assert(tracker.waitFor(0L) == 0) + sparkThread.join() + + assert(collectedAllReduceResults.poll().sameElements(maxVec)) + } test("build RDD containing boosters with the specified worker number") { val trainingRDD = buildTrainingRDD(sc) @@ -36,7 +78,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { trainingRDD, List("eta" -> "1", "max_depth" -> "6", "silent" -> "1", "objective" -> "binary:logistic").toMap, - new scala.collection.mutable.HashMap[String, String], + new java.util.HashMap[String, String](), numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true) val boosterCount = boosterRDD.count() assert(boosterCount === 2) @@ -59,6 +101,21 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { cleanExternalCache("XGBoostSuite") } + test("training with Scala-implemented Rabit tracker") { + val eval = new EvalError() + val trainingRDD = buildTrainingRDD(sc) + val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator + import DataUtils._ + val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null)) + val 
paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", + "tracker_conf" -> TrackerConf(1 minute, "scala")).toMap + val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, + nWorkers = numWorkers, useExternalMemory = true) + assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), + testSetDMatrix) < 0.1) + } + test("test with dense vectors containing missing value") { def buildDenseRDD(): RDD[LabeledPoint] = { val nrow = 100 diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml index 0efddd5dd..00cfed904 100644 --- a/jvm-packages/xgboost4j/pom.xml +++ b/jvm-packages/xgboost4j/pom.xml @@ -108,5 +108,17 @@ 4.11 test + + com.typesafe.akka + akka-actor_${scala.binary.version} + 2.3.11 + compile + + + com.typesafe.akka + akka-testkit_${scala.binary.version} + 2.3.11 + test + + diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java new file mode 100644 index 000000000..2a2fcd423 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java @@ -0,0 +1,43 @@ +package ml.dmlc.xgboost4j.java; + +import java.util.Map; +import java.util.concurrent.TimeUnit; + +/** + * Interface for Rabit tracker implementations with three public methods: + * + * - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given + * timeout value (in milliseconds.) + * - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients. + * - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout` + * milliseconds. + * + * Each implementation is expected to implement a callback function + * + * public void uncaughtException(Thread t, Throwable e) { ... } + * + * to interrupt waitFor() in order to prevent the tracker from hanging indefinitely. 
+ * + * The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and + * brokers connections between workers. + */ +public interface IRabitTracker extends Thread.UncaughtExceptionHandler { + public enum TrackerStatus { + SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3); + + private final int statusCode; + + TrackerStatus(int statusCode) { + this.statusCode = statusCode; + } + + public int getStatusCode() { + return this.statusCode; + } + } + + Map<String, String> getWorkerEnvs(); + boolean start(long workerConnectionTimeout); + // taskExecutionTimeout has no effect in current version of XGBoost. + int waitFor(long taskExecutionTimeout); +} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java index 3429dc3dd..6e996494b 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java @@ -2,6 +2,9 @@ package ml.dmlc.xgboost4j.java; import java.io.IOException; import java.io.Serializable; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; import java.util.Map; import org.apache.commons.logging.Log; @@ -22,6 +25,42 @@ public class Rabit { } } + public enum OpType implements Serializable { + MAX(0), MIN(1), SUM(2), BITWISE_OR(3); + + private final int op; + + public int getOperand() { + return this.op; + } + + OpType(int op) { + this.op = op; + } + } + + public enum DataType implements Serializable { + CHAR(0, 1), UCHAR(1, 1), INT(2, 4), UNIT(3, 4), + LONG(4, 8), ULONG(5, 8), FLOAT(6, 4), DOUBLE(7, 8), + LONGLONG(8, 8), ULONGLONG(9, 8); + + private final int enumOp; + private final int size; + + public int getEnumOp() { + return this.enumOp; + } + + public int getSize() { + return this.size; + } + + DataType(int enumOp, int size) { + this.enumOp = enumOp; + this.size = size; + } + } + private static void checkCall(int ret) throws XGBoostError { if 
(ret != 0) { throw new XGBoostError(XGBoostJNI.XGBGetLastError()); @@ -92,4 +131,30 @@ public class Rabit { checkCall(XGBoostJNI.RabitGetWorldSize(out)); return out[0]; } + + /** + * perform Allreduce on distributed float vectors using operator op. + * This implementation of allReduce does not support customized prepare function callback in the + * native code, as this function is meant for testing purposes only (to test the Rabit tracker.) + * + * @param elements local elements on distributed workers. + * @param op operator used for Allreduce. + * @return All-reduced float elements according to the given operator. + */ + public static float[] allReduce(float[] elements, OpType op) { + DataType dataType = DataType.FLOAT; + ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length) + .order(ByteOrder.nativeOrder()); + + for (float el : elements) { + buffer.putFloat(el); + } + buffer.flip(); + + XGBoostJNI.RabitAllreduce(buffer, elements.length, dataType.getEnumOp(), op.getOperand()); + float[] results = new float[elements.length]; + buffer.asFloatBuffer().get(results); + + return results; + } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java index bc419c564..d2008cd7f 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java @@ -5,15 +5,24 @@ package ml.dmlc.xgboost4j.java; import java.io.*; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; /** - * Distributed RabitTracker, need to be started on driver code before running distributed jobs. + * Java implementation of the Rabit tracker to coordinate distributed workers. 
+ * As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both + * start() and waitFor() methods (i.e., the timeout is infinite.) + * + * For systems lacking Python environment, or for timeout functionality, consider using the Scala + * Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and + * provides timeout support. + * + * The tracker must be started on driver node before running distributed jobs. */ -public class RabitTracker { +public class RabitTracker implements IRabitTracker { // Maybe per tracker logger? private static final Log logger = LogFactory.getLog(RabitTracker.class); // tracker python file. @@ -69,7 +78,6 @@ public class RabitTracker { } } - public RabitTracker(int numWorkers) throws XGBoostError { if (numWorkers < 1) { @@ -78,6 +86,17 @@ public class RabitTracker { this.numWorkers = numWorkers; } + public void uncaughtException(Thread t, Throwable e) { + logger.error("Uncaught exception thrown by worker:", e); + try { + Thread.sleep(5000L); + } catch (InterruptedException ex) { + logger.error(ex); + } finally { + trackerProcess.get().destroy(); + } + } + /** * Get environments that can be used to pass to worker. * @return The environment settings. @@ -126,7 +145,13 @@ public class RabitTracker { } } - public boolean start() { + public boolean start(long timeout) { + if (timeout > 0L) { + logger.warn("Python RabitTracker does not support timeout. " + + "The tracker will wait for all workers to connect indefinitely, unless " + + "it is interrupted manually. Use the Scala RabitTracker for timeout support."); + } + if (startTrackerProcess()) { logger.debug("Tracker started, with env=" + envs.toString()); System.out.println("Tracker started, with env=" + envs.toString()); @@ -142,7 +167,14 @@ public class RabitTracker { } } - public int waitFor() { + public int waitFor(long timeout) { + if (timeout > 0L) { + logger.warn("Python RabitTracker does not support timeout. 
" + + "The tracker will wait for either all workers to finish tasks and send " + + "shutdown signal, or manual interruptions. " + + "Use the Scala RabitTracker for timeout support."); + } + try { trackerProcess.get().waitFor(); int returnVal = trackerProcess.get().exitValue(); @@ -153,7 +185,7 @@ public class RabitTracker { // we should not get here as RabitTracker is accessed in the main thread e.printStackTrace(); logger.error("the RabitTracker thread is terminated unexpectedly"); - return 1; + return TrackerStatus.INTERRUPTED.getStatusCode(); } } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 4ecef65a7..630c61647 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -16,6 +16,8 @@ package ml.dmlc.xgboost4j.java; +import java.nio.ByteBuffer; + /** * xgboost JNI functions * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster @@ -97,4 +99,9 @@ class XGBoostJNI { public final static native int RabitGetRank(int[] out); public final static native int RabitGetWorldSize(int[] out); public final static native int RabitVersionNumber(int[] out); + + // Perform Allreduce operation on data in sendrecvbuf. + // This JNI function does not support the callback function for data preparation yet. 
+ final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count, + int enum_dtype, int enum_op); } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala new file mode 100644 index 000000000..d6ca42e75 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala @@ -0,0 +1,190 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.rabit + +import java.net.{InetAddress, InetSocketAddress} + +import akka.actor.ActorSystem +import akka.pattern.ask +import ml.dmlc.xgboost4j.java.IRabitTracker +import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler + +import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} +import scala.util.{Failure, Success, Try} + +/** + * Scala implementation of the Rabit tracker interface without Python dependency. + * The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker + * (and thus any distributed tasks) to hang indefinitely due to network issues or worker node + * failures. + * + * Note that this implementation is currently experimental, and should be used at your own risk. 
+ * + * Example usage: + * {{{ + * import scala.concurrent.duration._ + * + * val tracker = new RabitTracker(32) + * // allow up to 10 minutes for all workers to connect to the tracker. + * tracker.start(10.minutes.toMillis) + * + * /* ... + * launching workers in parallel + * ... + * */ + * + * // wait for worker execution up to 6 hours. + * // providing a finite timeout prevents a long-running task from hanging forever in + * // catastrophic events, like the loss of an executor during model training. + * tracker.waitFor(6.hours.toMillis) + * }}} + * + * @param numWorkers Number of distributed workers from which the tracker expects connections. + * @param port The minimum port number that the tracker binds to. + * If port is omitted, or given as None, a random ephemeral port is chosen at runtime. + * @param maxPortTrials The maximum number of trials of socket binding, by sequentially + * increasing the port number. + */ +private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None, + maxPortTrials: Int = 1000) + extends IRabitTracker { + + import scala.collection.JavaConverters._ + + require(numWorkers >=1, "numWorkers must be greater than or equal to one (1).") + + val system = ActorSystem.create("RabitTracker") + val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler") + implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds) + private[this] val tcpBindingTimeout: Duration = 1 minute + + var workerEnvs: Map[String, String] = Map.empty + + override def uncaughtException(t: Thread, e: Throwable): Unit = { + handler ? RabitTrackerHandler.InterruptTracker(e) + } + + /** + * Start the Rabit tracker. + * + * @param timeout The timeout for awaiting connections from worker nodes. + * Note that when used in Spark applications, because all Spark transformations are + * lazily executed, the I/O time for loading RDDs/DataFrames from external sources + * (local disk, HDFS, S3 etc.) must be taken into account for the timeout value. 
+ * If the timeout value is too small, the Rabit tracker will likely timeout before workers + * establishing connections to the tracker, due to the overhead of loading data. + * Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver + * running it) from hanging indefinitely due to worker connection issues (e.g. firewall.) + * @return Boolean flag indicating if the Rabit tracker starts successfully. + */ + private def start(timeout: Duration): Boolean = { + handler ? RabitTrackerHandler.StartTracker( + new InetSocketAddress(InetAddress.getLocalHost, port.getOrElse(0)), maxPortTrials, timeout) + + // block by waiting for the actor to bind to a port + Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration) + .asInstanceOf[Future[Map[String, String]]]) match { + case Success(futurePortBound) => + // The success of the Future is contingent on binding to an InetSocketAddress. + val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess + if (isBound) { + workerEnvs = Await.result(futurePortBound, 0 nano) + } + isBound + case Failure(ex: Throwable) => + false + } + } + + /** + * Start the Rabit tracker. + * + * @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker + * connections. If a non-positive value is provided, the tracker + * waits for incoming worker connections indefinitely. + * @return Boolean flag indicating if the Rabit tracker starts successfully. + */ + def start(connectionTimeoutMillis: Long): Boolean = { + if (connectionTimeoutMillis <= 0) { + start(Duration.Inf) + } else { + start(Duration.fromNanos(connectionTimeoutMillis * 1e6)) + } + } + + /** + * Get a Map of necessary environment variables to initiate Rabit workers. + * + * @return HashMap containing tracker information. 
+ */ + def getWorkerEnvs: java.util.Map[String, String] = { + new java.util.HashMap((workerEnvs ++ Map( + "DMLC_NUM_WORKER" -> numWorkers.toString, + "DMLC_NUM_SERVER" -> "0" + )).asJava) + } + + /** + * Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds. + * This method blocks until timeout or task completion. + * + * @param atMost the maximum execution time for the workers. By default, + * the tracker waits for the workers indefinitely. + * @return 0 if the tasks complete successfully, and non-zero otherwise. + */ + private def waitFor(atMost: Duration): Int = { + // request the completion Future from the tracker actor + Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration) + .asInstanceOf[Future[Int]]) match { + case Success(futureCompleted) => + // wait for all workers to complete synchronously. + val statusCode = Try(Await.result(futureCompleted, atMost)) match { + case Success(n) if n == numWorkers => + IRabitTracker.TrackerStatus.SUCCESS.getStatusCode + case Success(n) if n < numWorkers => + IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode + case Failure(e) => + IRabitTracker.TrackerStatus.FAILURE.getStatusCode + } + system.shutdown() + statusCode + case Failure(ex: Throwable) => + if (!system.isTerminated) { + system.shutdown() + } + IRabitTracker.TrackerStatus.FAILURE.getStatusCode + } + } + + /** + * Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds. + * This method blocks until timeout or task completion. + * + * @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a + * non-positive number is given, the tracker waits indefinitely. 
+ * @return 0 if the tasks complete successfully, and non-zero otherwise + */ + def waitFor(atMostMillis: Long): Int = { + if (atMostMillis <= 0) { + waitFor(Duration.Inf) + } else { + waitFor(Duration.fromNanos(atMostMillis * 1e6)) + } + } +} + diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala new file mode 100644 index 000000000..8b1c25a34 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala @@ -0,0 +1,362 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.rabit.handler + +import java.net.InetSocketAddress +import java.util.UUID + +import scala.concurrent.duration._ +import scala.collection.mutable +import scala.concurrent.{Promise, TimeoutException} +import akka.io.{IO, Tcp} +import akka.actor._ +import ml.dmlc.xgboost4j.java.XGBoostError +import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap} + +import scala.util.{Failure, Random, Success, Try} + +/** The Akka actor for handling and coordinating Rabit worker connections. + * This is the main actor for handling socket connections, interacting with the synchronous + * tracker interface, and resolving tree/ring/parent dependencies between workers. + * + * @param numWorkers Number of workers to track. 
+ */
+private[scala] class RabitTrackerHandler(numWorkers: Int)
+  extends Actor with ActorLogging {
+
+  import context.system
+  import RabitWorkerHandler._
+  import RabitTrackerHandler._
+
+  private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
+  private[this] val promisedShutdownWorkers = Promise[Int]()
+  private[this] val tcpManager = IO(Tcp)
+
+  // resolves worker connection dependency.
+  val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")
+
+  // workers that have sent "shutdown" signal
+  private[this] val shutdownWorkers = mutable.Set.empty[Int]
+  private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
+  private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
+  private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
+  private[this] var maxPortTrials = 0
+  private[this] var workerConnectionTimeout: Duration = Duration.Inf
+  private[this] var portTrials = 0
+  private[this] val startedWorkers = mutable.Set.empty[Int]
+
+  val linkMap = new LinkMap(numWorkers)
+
+  def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
+    rank match {
+      case r if r >= 0 => Some(r)
+      case _ =>
+        jobId match {
+          case "NULL" => None
+          case jid => jobToRankMap.get(jid)
+        }
+    }
+  }
+
+  /**
+   * Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
+   * by the RabitWorkerHandler.
+   *
+   * @param event Generic Tcp.Event
+   */
+  private def handleTcpEvents(event: Tcp.Event): Unit = event match {
+    case Tcp.Bound(local) =>
+      // expect all workers to connect within timeout
+      log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
+      log.info(s"Worker connection timeout is $workerConnectionTimeout.")
+
+      context.setReceiveTimeout(workerConnectionTimeout)
+      promisedWorkerEnvs.success(Map(
+        "DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
+        "DMLC_TRACKER_PORT" -> local.getPort.toString,
+        // not required because the world size will be communicated to the
+        // worker node after the rank is assigned.
+        "rabit_world_size" -> numWorkers.toString
+      ))
+
+    case Tcp.CommandFailed(cmd: Tcp.Bind) =>
+      if (portTrials < maxPortTrials) {
+        portTrials += 1
+        tcpManager ! Tcp.Bind(self,
+          new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1),
+          backlog = 256)
+      }
+
+    case Tcp.Connected(remote, local) =>
+      log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
+      // spawn a dedicated child actor to handle this connection and register it for Tcp events.
+      val workerHandler = context.actorOf(RabitWorkerHandler.props(
+        remote.getAddress.getHostAddress, numWorkers, self, sender()
+      ), s"ConnectionHandler-${UUID.randomUUID().toString}")
+      val connection = sender()
+      connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)
+
+      actorRefToHost.put(workerHandler, remote.getAddress.getHostName)
+  }
+
+  /**
+   * Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
+   * to interact with the tracker interface.
+   *
+   * @param trackerMsg control messages sent by RabitTracker class.
+   */
+  private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
+    trackerMsg match {
+
+      case msg: StartTracker =>
+        maxPortTrials = msg.maxPortTrials
+        workerConnectionTimeout = msg.connectionTimeout
+
+        // if the port number is missing, try binding to a random ephemeral port.
+        if (msg.addr.getPort == 0) {
+          tcpManager ! Tcp.Bind(self,
+            new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
+            backlog = 256)
+        } else {
+          tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
+        }
+        sender() ! true
+
+      case RequestBoundFuture =>
+        sender() ! promisedWorkerEnvs.future
+
+      case RequestCompletionFuture =>
+        sender() ! promisedShutdownWorkers.future
+
+      case InterruptTracker(e) =>
+        log.error(e, "Uncaught exception thrown by worker.")
+        // make sure that waitFor() does not hang indefinitely (no-op if already completed.)
+        promisedShutdownWorkers.tryFailure(e)
+        context.stop(self)
+    }
+
+  /**
+   * Handles messages sent by child actors representing connecting Rabit workers, by brokering
+   * messages to the dependency resolver, and processing worker commands.
+   *
+   * @param workerMsg Message sent by RabitWorkerHandler actors.
+   */
+  private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
+    case req @ RequestAwaitConnWorkers(_, _) =>
+      // since the requester may request to connect to other workers
+      // that have not fully set up, delegate this request to the
+      // dependency resolver which handles the dependencies properly.
+      resolver forward req
+
+    // ---- Rabit worker commands: start/recover/shutdown/print ----
+    case WorkerTrackerPrint(_, _, _, msg) =>
+      log.info(msg.trim)
+
+    case WorkerShutdown(rank, _, _) =>
+      assert(rank >= 0, "Invalid rank.")
+      assert(!shutdownWorkers.contains(rank))
+      shutdownWorkers.add(rank)
+
+      log.info(s"Received shutdown signal from $rank")
+
+      if (shutdownWorkers.size == numWorkers) {
+        // trySuccess: the promise may already have been failed by an invalid start signal.
+        promisedShutdownWorkers.trySuccess(shutdownWorkers.size)
+        context.stop(self)
+      }
+
+    case WorkerRecover(prevRank, worldSize, jobId) =>
+      assert(prevRank >= 0)
+      sender() ! linkMap.assignRank(prevRank)
+
+    case WorkerStart(rank, worldSize, jobId) =>
+      assert(worldSize == numWorkers || worldSize == -1,
+        s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
+      )
+      Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
+        case Success(wkRank) =>
+          if (jobId != "NULL") {
+            jobToRankMap.put(jobId, wkRank)
+          }
+
+          val assignedRank = linkMap.assignRank(wkRank)
+          sender() ! assignedRank
+          resolver ! assignedRank
+
+          log.info("Received start signal from " +
+            s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]")
+
+        case Failure(ex: IndexOutOfBoundsException) =>
+          // More than worldSize workers have connected, likely due to executor loss. Rabit has
+          // no crash recovery yet (Allreduce results are not cached by the tracker, and existing
+          // workers cannot reestablish connections to a newly spawned executor/worker), so the
+          // most reasonable action is to interrupt the tracker immediately with a failure state.
+          log.error("Received invalid start signal from " +
+            s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started."
+          )
+          promisedShutdownWorkers.tryFailure(new XGBoostError("Invalid start signal" +
+            " received from worker, likely due to executor loss."))
+
+        case Failure(ex) =>
+          log.error(ex, "Unexpected error")
+          promisedShutdownWorkers.tryFailure(ex)
+      }
+
+    // ---- Dependency resolving related messages ----
+    case msg @ WorkerStarted(host, rank, awaitingAcceptance) =>
+      log.info(s"Worker $host (rank: $rank) has started.")
+      resolver forward msg
+
+      startedWorkers.add(rank)
+      if (startedWorkers.size == numWorkers) {
+        log.info("All workers have started.")
+      }
+
+    case req @ DropFromWaitingList(_) =>
+      // all peer workers in dependency link map have connected;
+      // forward message to resolver to update dependencies.
+      resolver forward req
+
+    case _ =>
+  }
+
+  def receive: Actor.Receive = {
+    case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
+    case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
+    case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)
+
+    case akka.actor.ReceiveTimeout =>
+      if (startedWorkers.size < numWorkers) {
+        promisedShutdownWorkers.tryFailure(
+          new TimeoutException("Timed out waiting for workers to connect: " +
+            s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
+        )
+        context.stop(self)
+      }
+
+      context.setReceiveTimeout(Duration.Undefined)
+  }
+}
+
+/**
+ * Resolve the dependency between nodes as they connect to the tracker.
+ * The dependency is enforced that a worker of rank K depends on its neighbors (from the treeMap
+ * and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order of
+ * connections by workers, this dependency constraint assumes that a worker node that connects
+ * first is likely to finish setup first.
+ */
+private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
+  import RabitWorkerHandler._
+
+  context.watch(handler)
+
+  case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])
+
+  // worker nodes that have connected, but have not sent the WorkerStarted message.
+  private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
+  private val startedWorkers = mutable.Set.empty[Int]
+  // worker nodes that have started, and are awaiting connections.
+  private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
+  private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]
+
+  def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
+    val connSet = awaitConnWorkers.toMap
+      .filter { case (peerRank, _) => linkSet.contains(peerRank) }
+    AwaitingConnections(connSet, linkSet.size - connSet.size)
+  }
+
+  def receive: Actor.Receive = {
+    // a copy of the AssignedRank message that is also sent to the worker
+    case AssignedRank(rank, tree_neighbors, ring, parent) =>
+      // the workers that the worker of given `rank` depends on:
+      // worker of rank K only depends on workers with rank smaller than K.
+      val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
+        .filter { r => r != -1 && r < rank }
+
+      log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
+      dependencyMap.put(rank, dependentWorkers)
+
+    case RequestAwaitConnWorkers(rank, toConnectSet) =>
+      val promise = Promise[AwaitingConnections]()
+
+      assert(dependencyMap.contains(rank))
+
+      val updatedDependency = dependencyMap(rank) diff startedWorkers
+      if (updatedDependency.isEmpty) {
+        // all dependencies are satisfied
+        log.debug(s"Rank $rank has all dependencies satisfied.")
+        promise.success(awaitingWorkers(toConnectSet))
+      } else {
+        log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
+        // promise is pending fulfillment due to unresolved dependency
+        pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
+      }
+
+      sender() ! promise.future
+
+    case WorkerStarted(_, started, awaitingAcceptance) =>
+      startedWorkers.add(started)
+      if (awaitingAcceptance > 0) {
+        awaitConnWorkers.put(started, sender())
+      }
+
+      // remove the started rank from all dependencies (iterate a snapshot while updating.)
+      dependencyMap.remove(started)
+      dependencyMap.toList.foreach { case (r, dset) =>
+        val updatedDependency = dset diff startedWorkers
+        // fulfill the future if all dependencies are met (started.)
+        if (updatedDependency.isEmpty) {
+          log.debug(s"Rank $r has all dependencies satisfied.")
+          pendingFulfillment.remove(r).foreach {
+            case Fulfillment(toConnectSet, promise) =>
+              promise.success(awaitingWorkers(toConnectSet))
+          }
+        }
+
+        dependencyMap.update(r, updatedDependency)
+      }
+
+    case DropFromWaitingList(rank) =>
+      // remove outside of assert(): an elided assertion would skip the side effect.
+      val dropped = awaitConnWorkers.remove(rank)
+      assert(dropped.isDefined)
+
+    case Terminated(ref) =>
+      if (ref.equals(handler)) {
+        context.stop(self)
+      }
+  }
+}
+
+private[scala] object RabitTrackerHandler {
+  // Messages sent by RabitTracker to this RabitTrackerHandler actor
+  trait TrackerControlMessage
+  case object RequestCompletionFuture extends TrackerControlMessage
+  case object RequestBoundFuture extends TrackerControlMessage
+  // Start the Rabit tracker at given socket address awaiting worker connections.
+  // All workers must connect to the tracker before connectionTimeout, otherwise the tracker will
+  // shut down due to timeout.
+  case class StartTracker(addr: InetSocketAddress,
+                          maxPortTrials: Int,
+                          connectionTimeout: Duration) extends TrackerControlMessage
+  // To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
+  // driver for the distributed training.
+  case class InterruptTracker(e: Throwable) extends TrackerControlMessage
+
+  def props(numWorkers: Int): Props =
+    Props(new RabitTrackerHandler(numWorkers))
+}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala
new file mode 100644
index 000000000..31acfc1ce
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala
@@ -0,0 +1,456 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.rabit.handler
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import akka.io.Tcp
+import akka.actor._
+import akka.util.ByteString
+import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}
+
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration._
+import scala.util.Try
+
+/**
+ * Actor to handle socket communication from worker node.
+ * To handle fragmentation in received data, this class acts like a FSM
+ * (finite-state machine) to keep track of the internal states.
+ *
+ * @param host IP address of the remote worker
+ * @param worldSize number of total workers
+ * @param tracker the RabitTrackerHandler actor reference
+ * @param connection actor to which Tcp.Write / SuspendReading / ResumeReading commands are sent
+ */
+private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
+                                        connection: ActorRef)
+  extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
+    with ActorLogging with Stash {
+
+  import RabitWorkerHandler._
+  import RabitTrackerHelpers._
+
+  private[this] var rank: Int = 0
+  private[this] var port: Int = 0
+
+  // indicate if the connection is transient (like "print" or "shutdown")
+  private[this] var transient: Boolean = false
+  private[this] var peerClosed: Boolean = false
+
+  // number of workers pending acceptance of current worker
+  private[this] var awaitingAcceptance: Int = 0
+  private[this] var neighboringWorkers = Set.empty[Int]
+
+  // TODO: use a single memory allocation to host all buffers,
+  // including the transient ones for writing.
+  private[this] val readBuffer = ByteBuffer.allocate(4096)
+    .order(ByteOrder.nativeOrder())
+  // in case the received message is longer than needed,
+  // stash the spilled over part in this buffer, and send
+  // to self when transition occurs.
+  private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
+    .order(ByteOrder.nativeOrder())
+  // when setup is complete, need to notify peer handlers
+  // to reduce the awaiting-connection counter.
+  private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None
+
+  private def resetBuffers(): Unit = {
+    readBuffer.clear()
+    if (spillOverBuffer.position() > 0) {
+      spillOverBuffer.flip()
+      self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
+      spillOverBuffer.clear()
+    }
+  }
+
+  private def stashSpillOver(buf: ByteBuffer): Unit = {
+    if (buf.remaining() > 0) spillOverBuffer.put(buf)
+  }
+
+  def getNeighboringWorkers: Set[Int] = neighboringWorkers
+
+  def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
+    val rank = buffer.getInt()
+    val worldSize = buffer.getInt()
+    val jobId = buffer.getString
+
+    val command = buffer.getString
+    command match {
+      case "start" => WorkerStart(rank, worldSize, jobId)
+      case "shutdown" =>
+        transient = true
+        WorkerShutdown(rank, worldSize, jobId)
+      case "recover" =>
+        require(rank >= 0, "Invalid rank for recovering worker.")
+        WorkerRecover(rank, worldSize, jobId)
+      case "print" =>
+        transient = true
+        WorkerTrackerPrint(rank, worldSize, jobId, buffer.getString)
+      case other => throw new IllegalArgumentException(s"Unknown command '$other' from $host")
+    }
+  }
+
+  startWith(AwaitingHandshake, DataStruct())
+
+  when(AwaitingHandshake) {
+    case Event(Tcp.Received(magic), _) =>
+      assert(magic.length == 4)
+      val purportedMagic = magic.asNativeOrderByteBuffer.getInt
+      assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")
+
+      // echo back the magic number
+      connection ! Tcp.Write(magic)
+      goto(AwaitingCommand) using StructTrackerCommand
+  }
+
+  when(AwaitingCommand) {
+    case Event(Tcp.Received(bytes), validator) =>
+      bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
+      if (validator.verify(readBuffer)) {
+        readBuffer.flip()
+        tracker ! decodeCommand(readBuffer)
+        stashSpillOver(readBuffer)
+      }
+
+      stay
+    // when rank for a worker is assigned, send encoded rank information
+    // back to worker over Tcp socket.
+    case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
+      log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")
+
+      rank = assignedRank
+      // ranks from the ring
+      val ringRanks = List(
+        // ringPrev
+        if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
+        // ringNext
+        if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
+      )
+
+      // update the set of all linked workers to current worker.
+      neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet
+
+      connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
+      // to prevent reading before state transition
+      connection ! Tcp.SuspendReading
+      goto(BuildingLinkMap) using StructNodes
+  }
+
+  when(BuildingLinkMap) {
+    case Event(Tcp.Received(bytes), validator) =>
+      bytes.asByteBuffers.foreach { buf =>
+        readBuffer.put(buf)
+      }
+
+      if (validator.verify(readBuffer)) {
+        readBuffer.flip()
+        // for a freshly started worker, numConnected should be 0.
+        val numConnected = readBuffer.getInt()
+        val toConnectSet = neighboringWorkers.diff(
+          (0 until numConnected).map { index => readBuffer.getInt() }.toSet)
+
+        // check which workers are currently awaiting connections
+        tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
+      }
+      stay
+
+    // got a Future from the tracker (resolver) about workers that are
+    // currently awaiting connections (particularly from this node.)
+    case Event(future: Future[_], _) =>
+      // blocks execution until all dependencies for current worker are resolved.
+      Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
+        // numNotReachable is the number of workers that currently
+        // cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
+        // connections from those currently non-reachable nodes in the future.
+        case AwaitingConnections(waitConnNodes, numNotReachable) =>
+          log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")
+          val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
+          buf.putInt(waitConnNodes.size).putInt(numNotReachable)
+          buf.flip()
+
+          // cache this message until the final state (SetupComplete)
+          pendingAcknowledgement = Some(AcknowledgeAcceptance(
+            waitConnNodes, numNotReachable))
+
+          connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
+          if (waitConnNodes.isEmpty) {
+            connection ! Tcp.SuspendReading
+            goto(AwaitingErrorCount)
+          } else {
+            waitConnNodes.foreach { case (peerRank, peerRef) =>
+              peerRef ! RequestWorkerHostPort
+            }
+
+            // a countdown for DivulgedHostPort messages.
+            stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
+          }
+      }
+
+    case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
+      val hostBytes = peerHost.getBytes()
+      val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
+        .order(ByteOrder.nativeOrder())
+      // the length prefix must be the encoded byte count, matching the bytes written.
+      buffer.putInt(hostBytes.length).put(hostBytes)
+        .putInt(peerPort).putInt(peerRank)
+
+      buffer.flip()
+      connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))
+      if (data.counter == 0) {
+        // to prevent reading before state transition
+        connection ! Tcp.SuspendReading
+        goto(AwaitingErrorCount)
+      } else {
+        stay using data.decrement()
+      }
+  }
+
+  when(AwaitingErrorCount) {
+    case Event(Tcp.Received(numErrors), _) =>
+      val buf = numErrors.asNativeOrderByteBuffer
+
+      buf.getInt match {
+        case 0 =>
+          stashSpillOver(buf)
+          goto(AwaitingPortNumber)
+        case _ =>
+          stashSpillOver(buf)
+          goto(BuildingLinkMap) using StructNodes
+      }
+  }
+
+  when(AwaitingPortNumber) {
+    case Event(Tcp.Received(assignedPort), _) =>
+      assert(assignedPort.length == 4)
+      port = assignedPort.asNativeOrderByteBuffer.getInt
+      log.debug(s"Rank $rank listening @ $host:$port")
+      // wait until the worker closes connection.
+      if (peerClosed) goto(SetupComplete) else stay
+
+    case Event(Tcp.PeerClosed, _) =>
+      peerClosed = true
+      if (port == 0) stay else goto(SetupComplete)
+  }
+
+  when(SetupComplete) {
+    case Event(ReduceWaitCount(count: Int), _) =>
+      awaitingAcceptance -= count
+      // check peerClosed to avoid prematurely stopping this actor (which sends RST to worker)
+      if (awaitingAcceptance == 0 && peerClosed) {
+        tracker ! DropFromWaitingList(rank)
+        // no longer needed.
+        context.stop(self)
+      }
+      stay
+
+    case Event(AcknowledgeAcceptance(peers, numBad), _) =>
+      awaitingAcceptance = numBad
+      tracker ! WorkerStarted(host, rank, awaitingAcceptance)
+      peers.values.foreach { peer =>
+        peer ! ReduceWaitCount(1)
+      }
+
+      if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill
+
+      stay
+
+    // can only divulge the complete host and port information
+    // when this worker is declared fully connected (otherwise
+    // port information is still missing.)
+    case Event(RequestWorkerHostPort, _) =>
+      sender() ! DivulgedWorkerHostPort(rank, host, port)
+      stay
+  }
+
+  onTransition {
+    // reset buffer when state transitions as data becomes stale
+    case _ -> SetupComplete =>
+      connection ! Tcp.ResumeReading
+      resetBuffers()
+      pendingAcknowledgement.foreach { ack => self ! ack }
+    case _ =>
+      connection ! Tcp.ResumeReading
+      resetBuffers()
+  }
+
+  // default message handler
+  whenUnhandled {
+    case Event(Tcp.PeerClosed, _) =>
+      peerClosed = true
+      if (transient) context.stop(self)
+      stay
+  }
+}
+
+private[scala] object RabitWorkerHandler {
+  val MAGIC_NUMBER = 0xff99
+
+  // Finite states of this actor, which acts like a FSM.
+  // The following states are defined in order as the FSM progresses.
+  sealed trait State
+
+  // [1] Initial state, awaiting worker to send magic number per protocol.
+  case object AwaitingHandshake extends State
+  // [2] Awaiting worker to send command (start/print/recover/shutdown etc.)
+  case object AwaitingCommand extends State
+  // [3] Brokers connections between workers per ring/tree/parent link map.
+  case object BuildingLinkMap extends State
+  // [4] A transient state in which the worker reports the number of errors in establishing
+  // connections to other peer workers. If no errors, transition to next state.
+  case object AwaitingErrorCount extends State
+  // [5] Awaiting the worker to report its port number for accepting connections from peer workers.
+  // This port number information is later forwarded to linked workers.
+  case object AwaitingPortNumber extends State
+  // [6] Final state after completing the setup with the connecting worker. At this stage, the
+  // worker will have closed the Tcp connection. The actor remains alive to handle messages from
+  // peer actors representing workers with pending setups.
+  case object SetupComplete extends State
+
+  sealed trait DataField
+  case object IntField extends DataField
+  // an integer preceding the actual string
+  case object StringField extends DataField
+  case object IntSeqField extends DataField
+
+  object DataStruct {
+    def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
+  }
+
+  // Internal data pertaining to individual state, used to verify the validity of packets sent by
+  // workers.
+  case class DataStruct(fields: Seq[DataField], counter: Int) {
+    /**
+     * Validate whether the provided buffer is complete (i.e., contains
+     * all data fields specified for this DataStruct.)
+     *
+     * @param buf a byte buffer containing received data.
+     */
+    def verify(buf: ByteBuffer): Boolean = {
+      if (fields.isEmpty) return true
+
+      val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
+      dupBuf.flip()
+
+      Try(fields.foldLeft(true) {
+        case (complete, field) =>
+          val remBytes = dupBuf.remaining()
+          complete && (remBytes > 0) && (remBytes >= (field match {
+            case IntField =>
+              dupBuf.position(dupBuf.position() + 4)
+              4
+            case StringField =>
+              val strLen = dupBuf.getInt
+              dupBuf.position(dupBuf.position() + strLen)
+              4 + strLen
+            case IntSeqField =>
+              val seqLen = dupBuf.getInt
+              dupBuf.position(dupBuf.position() + seqLen * 4)
+              4 + seqLen * 4
+          }))
+      }).getOrElse(false)
+    }
+
+    def increment(): DataStruct = DataStruct(fields, counter + 1)
+    def decrement(): DataStruct = DataStruct(fields, counter - 1)
+  }
+
+  val StructNodes = DataStruct(List(IntSeqField), 0)
+  val StructTrackerCommand = DataStruct(List(
+    IntField, IntField, StringField, StringField
+  ), 0)
+
+  // ---- Messages between RabitTrackerHandler and RabitWorkerHandler ----
+
+  // RabitWorkerHandler --> RabitTrackerHandler
+  sealed trait RabitWorkerRequest
+  // RabitWorkerHandler <-- RabitTrackerHandler
+  sealed trait RabitWorkerResponse
+
+  // Representations of decoded worker commands.
+  abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
+    def rank: Int
+    def worldSize: Int
+    def jobId: String
+
+    def encode: ByteString = {
+      // buffer capacity and length prefixes count encoded bytes, not UTF-16 chars.
+      val (jobIdBytes, commandBytes) = (jobId.getBytes(), command.getBytes())
+      val buf = ByteBuffer.allocate(4 * 4 + jobIdBytes.length + commandBytes.length)
+        .order(ByteOrder.nativeOrder())
+      buf.putInt(rank).putInt(worldSize).putInt(jobIdBytes.length).put(jobIdBytes)
+        .putInt(commandBytes.length).put(commandBytes).flip()
+
+      ByteString.fromByteBuffer(buf)
+    }
+  }
+
+  case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
+    extends TrackerCommand("start")
+  case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
+    extends TrackerCommand("shutdown")
+  case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
+    extends TrackerCommand("recover")
+  case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
+    extends TrackerCommand("print") {
+
+    override def encode: ByteString = {
+      // buffer capacity and length prefixes count encoded bytes, not UTF-16 chars.
+      val (jobIdBytes, commandBytes, msgBytes) =
+        (jobId.getBytes(), command.getBytes(), msg.getBytes())
+      val buf = ByteBuffer
+        .allocate(4 * 5 + jobIdBytes.length + commandBytes.length + msgBytes.length)
+        .order(ByteOrder.nativeOrder())
+      buf.putInt(rank).putInt(worldSize).putInt(jobIdBytes.length).put(jobIdBytes)
+        .putInt(commandBytes.length).put(commandBytes).putInt(msgBytes.length).put(msgBytes).flip()
+      ByteString.fromByteBuffer(buf)
+    }
+  }
+
+  // Request to remove the worker of given rank from the list of workers awaiting peer connections.
+  case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
+  // Notify the tracker that the worker of given rank has finished setup and started.
+  case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
+    extends RabitWorkerRequest
+  // Request the set of workers to connect to, according to the LinkMap structure.
+  case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
+    extends RabitWorkerRequest
+
+  // Response from the tracker: peers currently awaiting connections, and the count of the rest.
+  case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
+    extends RabitWorkerResponse
+
+  // ---- Messages between ConnectionHandler actors ----
+  sealed trait IntraWorkerMessage
+
+  // Notify neighboring workers to decrease the counter of awaiting workers by `count`.
+  case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
+  // Request host and port information from peer ConnectionHandler actors (acting on behalf of
+  // connecting workers.) This message will be brokered by RabitTrackerHandler.
+  case object RequestWorkerHostPort extends IntraWorkerMessage
+  // Response to the above request
+  case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
+  // A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
+  case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
+    extends IntraWorkerMessage
+
+  // ---- End of message definitions ----
+
+  def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
+    Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
+  }
+}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala
new file mode 100644
index 000000000..edec4931b
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala
@@ -0,0 +1,136 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.rabit.util + +import java.nio.{ByteBuffer, ByteOrder} + +/** + * The assigned rank to a connecting Rabit worker, along with the information of the ranks of + * its linked peer workers, which are critical to perform Allreduce. + * When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker + * client, RabitTrackerHandler utilizes LinkMap to figure out linkage relationships, and respond + * with this class as a message, which is later encoded to byte string, and sent over socket + * connection to the worker client. + * + * @param rank assigned rank (ranked by worker connection order: first worker connecting to the + * tracker is assigned rank 0, second with rank 1, etc.) + * @param neighbors ranks of neighboring workers in a tree map. + * @param ring ranks of neighboring workers in a ring map. + * @param parent rank of the parent worker. + */ +private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int], + ring: (Int, Int), parent: Int) { + /** + * Encode the AssignedRank message into byte sequence for socket communication with Rabit worker + * client. + * @param worldSize the number of total distributed workers. Must match `numWorkers` used in + * LinkMap. + * @return a ByteBuffer containing encoded data. 
+ */ + def toByteBuffer(worldSize: Int): ByteBuffer = { + val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder()) + buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length) + // neighbors in tree structure + neighbors.foreach { n => buffer.putInt(n) } + buffer.putInt(if (ring._1 != -1 && ring._1 != rank) ring._1 else -1) + buffer.putInt(if (ring._2 != -1 && ring._2 != rank) ring._2 else -1) + + buffer.flip() + buffer + } +} + +private[rabit] class LinkMap(numWorkers: Int) { + private def getNeighbors(rank: Int): Seq[Int] = { + val rank1 = rank + 1 + Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r => + r >= 0 && r < numWorkers + } + } + + /** + * Construct a ring structure that tends to share nodes with the tree. + * + * @param treeMap + * @param parentMap + * @param rank + * @return Seq[Int] instance starting from rank. + */ + private def constructShareRing(treeMap: Map[Int, Seq[Int]], + parentMap: Map[Int, Int], + rank: Int = 0): Seq[Int] = { + treeMap(rank).toSet - parentMap(rank) match { + case emptySet if emptySet.isEmpty => + List(rank) + case connectionSet => + connectionSet.zipWithIndex.foldLeft(List(rank)) { + case (ringSeq, (v, cnt)) => + val vConnSeq = constructShareRing(treeMap, parentMap, v) + vConnSeq match { + case vconn if vconn.size == cnt + 1 => + ringSeq ++ vconn.reverse + case vconn => + ringSeq ++ vconn + } + } + } + } + /** + * Construct a ring connection used to recover local data. 
+ * + * @param treeMap + * @param parentMap + */ + private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = { + assert(parentMap(0) == -1) + + val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector + assert(sharedRing.length == treeMap.size) + + (0 until numWorkers).map { r => + val rPrev = (r + numWorkers - 1) % numWorkers + val rNext = (r + 1) % numWorkers + sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext)) + }.toMap + } + + private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap + private[this] val parentMap_ = (0 until numWorkers).map{ r => r -> ((r + 1) / 2 - 1) }.toMap + private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_) + val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) { + case ((rmap, k), i) => + val kNext = ringMap_(k)._2 + (rmap ++ Map(kNext -> (i + 1)), kNext) + }._1 + + val ringMap = ringMap_.map { + case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1)) + } + val treeMap = treeMap_.map { + case (k, vSeq) => rMap_(k) -> vSeq.map{ v => rMap_(v) } + } + val parentMap = parentMap_.map { + case (k, v) if k == 0 => + rMap_(k) -> -1 + case (k, v) => + rMap_(k) -> rMap_(v) + } + + def assignRank(rank: Int): AssignedRank = { + AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank)) + } +} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala new file mode 100644 index 000000000..3d7be618d --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala @@ -0,0 +1,39 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.rabit.util + +import java.nio.{ByteOrder, ByteBuffer} +import akka.util.ByteString + +private[rabit] object RabitTrackerHelpers { + implicit class ByteStringHelplers(bs: ByteString) { + // Java by default uses big endian. Enforce native endian so that + // the byte order is consistent with the workers. + def asNativeOrderByteBuffer: ByteBuffer = { + bs.asByteBuffer.order(ByteOrder.nativeOrder()) + } + } + + implicit class ByteBufferHelpers(buf: ByteBuffer) { + def getString: String = { + val len = buf.getInt() + val stringBuffer = ByteBuffer.allocate(len).order(ByteOrder.nativeOrder()) + buf.get(stringBuffer.array(), 0, len) + new String(stringBuffer.array(), "utf-8") + } + } +} diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index db4f93b44..0c4a85dcc 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -94,6 +94,7 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext( long max_elem = cbatch.offset[cbatch.size]; cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0); cbatch.value = jenv->GetFloatArrayElements(jvalue, 0); + CHECK_EQ(jenv->GetArrayLength(jindex), max_elem) << "batch.index.length must equal batch.offset.back()"; CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem) @@ -756,3 +757,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber jenv->SetIntArrayRegion(jout, 0, 1, &out); return 0; } + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: 
RabitAllreduce + * Signature: (Ljava/nio/ByteBuffer;III)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce + (JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) { + void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf); + RabitAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op, NULL, NULL); + + return 0; +} diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 15410abed..8e42eea1c 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -303,6 +303,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber (JNIEnv *, jclass, jintArray); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: RabitAllreduce + * Signature: (Ljava/nio/ByteBuffer;III)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce + (JNIEnv *, jclass, jobject, jint, jint, jint); + #ifdef __cplusplus } #endif diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala new file mode 100644 index 000000000..ee4febe39 --- /dev/null +++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala @@ -0,0 +1,224 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.rabit + +import java.nio.{ByteBuffer, ByteOrder} + +import akka.actor.{ActorRef, ActorSystem} +import akka.io.Tcp +import akka.testkit.{ImplicitSender, TestFSMRef, TestKit, TestProbe} +import akka.util.ByteString +import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler +import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler._ +import ml.dmlc.xgboost4j.scala.rabit.util.LinkMap +import org.junit.runner.RunWith +import org.scalatest.junit.JUnitRunner +import org.scalatest.{FlatSpecLike, Matchers} + +import scala.concurrent.Promise + +object RabitTrackerConnectionHandlerTest { + def intSeqToByteString(seq: Seq[Int]): ByteString = { + val buf = ByteBuffer.allocate(seq.length * 4).order(ByteOrder.nativeOrder()) + seq.foreach { i => buf.putInt(i) } + buf.flip() + ByteString.fromByteBuffer(buf) + } +} + +@RunWith(classOf[JUnitRunner]) +class RabitTrackerConnectionHandlerTest + extends TestKit(ActorSystem("RabitTrackerConnectionHandlerTest")) + with FlatSpecLike with Matchers with ImplicitSender { + + import RabitTrackerConnectionHandlerTest._ + + val magic = intSeqToByteString(List(0xff99)) + + "RabitTrackerConnectionHandler" should "handle Rabit client 'start' command properly" in { + val trackerProbe = TestProbe() + val connProbe = TestProbe() + + val worldSize = 4 + + val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize, + trackerProbe.ref, connProbe.ref)) + fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake + + // send mock magic number + fsm ! 
Tcp.Received(magic) + connProbe.expectMsg(Tcp.Write(magic)) + + fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand + fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand + // ResumeReading should be seen once state transitions + connProbe.expectMsg(Tcp.ResumeReading) + + // send mock tracker command in fragments: the handler should be able to handle it. + val bufRank = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder()) + bufRank.putInt(0).putInt(worldSize).flip() + + val bufJobId = ByteBuffer.allocate(5).order(ByteOrder.nativeOrder()) + bufJobId.putInt(1).put(Array[Byte]('0')).flip() + + val bufCmd = ByteBuffer.allocate(9).order(ByteOrder.nativeOrder()) + bufCmd.putInt(5).put("start".getBytes()).flip() + + fsm ! Tcp.Received(ByteString.fromByteBuffer(bufRank)) + fsm ! Tcp.Received(ByteString.fromByteBuffer(bufJobId)) + + // the state should not change for incomplete command data. + fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand + + // send the last fragment, and expect message at tracker actor. + fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd)) + trackerProbe.expectMsg(WorkerStart(0, worldSize, "0")) + + val linkMap = new LinkMap(worldSize) + val assignedRank = linkMap.assignRank(0) + trackerProbe.reply(assignedRank) + + connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer( + assignedRank.toByteBuffer(worldSize) + ))) + + // reading should be suspended upon transitioning to BuildingLinkMap + connProbe.expectMsg(Tcp.SuspendReading) + // state should transition with according state data changes. + fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap + fsm.stateData shouldEqual RabitWorkerHandler.StructNodes + connProbe.expectMsg(Tcp.ResumeReading) + + // since the connection handler in test has rank 0, it will not have any nodes to connect to. + fsm ! 
Tcp.Received(intSeqToByteString(List(0))) + trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers)) + + // return mock response to the connection handler + val awaitConnPromise = Promise[AwaitingConnections]() + awaitConnPromise.success(AwaitingConnections(Map.empty[Int, ActorRef], + fsm.underlyingActor.getNeighboringWorkers.size + )) + fsm ! awaitConnPromise.future + connProbe.expectMsg(Tcp.Write( + intSeqToByteString(List(0, fsm.underlyingActor.getNeighboringWorkers.size)) + )) + connProbe.expectMsg(Tcp.SuspendReading) + fsm.stateName shouldEqual RabitWorkerHandler.AwaitingErrorCount + connProbe.expectMsg(Tcp.ResumeReading) + + // send mock error count (0) + fsm ! Tcp.Received(intSeqToByteString(List(0))) + + fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber + connProbe.expectMsg(Tcp.ResumeReading) + + // simulate Tcp.PeerClosed event first, then Tcp.Received to test handling of async events. + fsm ! Tcp.PeerClosed + // state should not transition + fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber + fsm ! Tcp.Received(intSeqToByteString(List(32768))) + + fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete + connProbe.expectMsg(Tcp.ResumeReading) + + trackerProbe.expectMsg(RabitWorkerHandler.WorkerStarted("localhost", 0, 2)) + + val handlerStopProbe = TestProbe() + handlerStopProbe watch fsm + + // simulate connections from other workers by mocking ReduceWaitCount commands + fsm ! RabitWorkerHandler.ReduceWaitCount(1) + fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete + fsm ! RabitWorkerHandler.ReduceWaitCount(1) + trackerProbe.expectMsg(RabitWorkerHandler.DropFromWaitingList(0)) + handlerStopProbe.expectTerminated(fsm) + + // all done. 
+ } + + it should "forward print command to tracker" in { + val trackerProbe = TestProbe() + val connProbe = TestProbe() + + val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4, + trackerProbe.ref, connProbe.ref)) + fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake + + fsm ! Tcp.Received(magic) + connProbe.expectMsg(Tcp.Write(magic)) + + fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand + fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand + // ResumeReading should be seen once state transitions + connProbe.expectMsg(Tcp.ResumeReading) + + val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!") + fsm ! Tcp.Received(printCmd.encode) + + trackerProbe.expectMsg(printCmd) + } + + it should "handle spill-over Tcp data correctly between state transition" in { + val trackerProbe = TestProbe() + val connProbe = TestProbe() + + val worldSize = 4 + + val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize, + trackerProbe.ref, connProbe.ref)) + fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake + + // send mock magic number + fsm ! Tcp.Received(magic) + connProbe.expectMsg(Tcp.Write(magic)) + + fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand + fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand + // ResumeReading should be seen once state transitions + connProbe.expectMsg(Tcp.ResumeReading) + + // send mock tracker command in fragments: the handler should be able to handle it. + val bufCmd = ByteBuffer.allocate(26).order(ByteOrder.nativeOrder()) + bufCmd.putInt(0).putInt(worldSize).putInt(1).put(Array[Byte]('0')) + .putInt(5).put("start".getBytes()) + // spilled-over data + .putInt(0).flip() + + // send data with 4 extra bytes corresponding to the next state. + fsm ! 
Tcp.Received(ByteString.fromByteBuffer(bufCmd)) + + trackerProbe.expectMsg(WorkerStart(0, worldSize, "0")) + + val linkMap = new LinkMap(worldSize) + val assignedRank = linkMap.assignRank(0) + trackerProbe.reply(assignedRank) + + connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer( + assignedRank.toByteBuffer(worldSize) + ))) + + // reading should be suspended upon transitioning to BuildingLinkMap + connProbe.expectMsg(Tcp.SuspendReading) + // state should transition with according state data changes. + fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap + fsm.stateData shouldEqual RabitWorkerHandler.StructNodes + connProbe.expectMsg(Tcp.ResumeReading) + + // the handler should be able to handle spill-over data, and stash it until state transition. + trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers)) + } +}