From 564df59204f75a305797fa0dbf70a635c797fd47 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 20 Apr 2023 16:29:35 +0800 Subject: [PATCH] [breaking] [jvm-packages] Remove scala-implemented tracker. (#9045) --- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 20 +- .../spark/CommunicatorRobustnessSuite.scala | 121 +---- .../xgboost4j/scala/rabit/RabitTracker.scala | 195 -------- .../rabit/handler/RabitTrackerHandler.scala | 361 -------------- .../rabit/handler/RabitWorkerHandler.scala | 467 ------------------ .../xgboost4j/scala/rabit/util/LinkMap.scala | 136 ----- .../rabit/util/RabitTrackerHelpers.scala | 39 -- .../RabitTrackerConnectionHandlerTest.scala | 255 ---------- 8 files changed, 9 insertions(+), 1585 deletions(-) delete mode 100644 jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala delete mode 100644 jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala delete mode 100644 jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala delete mode 100644 jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala delete mode 100644 jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala delete mode 100644 jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 281997295..0aeae791a 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ import scala.util.Random import scala.collection.JavaConverters._ import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker} -import ml.dmlc.xgboost4j.scala.rabit.RabitTracker import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} @@ -44,21 +43,16 @@ import org.apache.spark.sql.SparkSession * Use a finite, non-zero timeout value to prevent tracker from * hanging indefinitely (in milliseconds) * (supported by "scala" implementation only.) - * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of - * the Python Rabit tracker (in dmlc_core), whereas the latter is implemented - * in Scala without Python components, and with full support of timeouts. - * The Scala implementation is currently experimental, use at your own risk. - * * @param hostIp The Rabit Tracker host IP address which is only used for python implementation. * This is only needed if the host IP cannot be automatically guessed. * @param pythonExec The python executed path for Rabit Tracker, * which is only used for python implementation. */ -case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String, +case class TrackerConf(workerConnectionTimeout: Long, hostIp: String = "", pythonExec: String = "") object TrackerConf { - def apply(): TrackerConf = TrackerConf(0L, "python") + def apply(): TrackerConf = TrackerConf(0L) } private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int, @@ -349,11 +343,9 @@ object XGBoost extends Serializable { /** visiable for testing */ private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = { - val tracker: IRabitTracker = trackerConf.trackerImpl match { - case "scala" => new RabitTracker(nWorkers) - case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec) - case _ => new PyRabitTracker(nWorkers) - } + val tracker: IRabitTracker = new PyRabitTracker( + nWorkers, trackerConf.hostIp, trackerConf.pythonExec + ) tracker } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala index 579e3dd37..04081c3fe 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala @@ -22,7 +22,6 @@ import scala.util.Random import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus -import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker} import ml.dmlc.xgboost4j.scala.DMatrix import org.scalatest.FunSuite @@ -40,7 +39,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest { val paramMap = Map( "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, "python", hostIp)) + "tracker_conf" -> TrackerConf(0L, hostIp)) val xgbExecParams = getXGBoostExecutionParams(paramMap) val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) tracker match { @@ -53,7 +52,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest { val paramMap1 = Map( "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, "python", "", pythonExec)) + "tracker_conf" -> TrackerConf(0L, "", pythonExec)) val xgbExecParams1 = getXGBoostExecutionParams(paramMap1) val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf) tracker1 match { @@ -66,7 +65,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest { val paramMap2 = Map( "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec)) + "tracker_conf" -> TrackerConf(0L, hostIp, pythonExec)) val xgbExecParams2 = getXGBoostExecutionParams(paramMap2) val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf) tracker2 match { @@ -78,58 +77,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest { } } - test("training with Scala-implemented Rabit tracker") { - val eval = new EvalError() - val training = buildDataFrame(Classification.train) - val testDM = new DMatrix(Classification.test.iterator) - val paramMap = Map("eta" -> "1", "max_depth" -> "6", - "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")) - val model = new XGBoostClassifier(paramMap).fit(training) - assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1) - } - - test("test Communicator allreduce to validate Scala-implemented Rabit tracker") { - val vectorLength = 100 - val rdd = sc.parallelize( - (1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache() - - val tracker = new ScalaRabitTracker(numWorkers) - tracker.start(0) - val trackerEnvs = tracker.getWorkerEnvs - val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]() - - val rawData = rdd.mapPartitions { iter => - Iterator(iter.toArray) - }.collect() - - val maxVec = (0 until vectorLength).toArray.map { j => - (0 until numWorkers).toArray.map { i => rawData(i)(j) }.max - } - - val allReduceResults = rdd.mapPartitions { iter => - Communicator.init(trackerEnvs) - val arr = iter.toArray - val results = Communicator.allReduce(arr, Communicator.OpType.MAX) - Communicator.shutdown() - Iterator(results) - }.cache() - - val sparkThread = new Thread() { - override def run(): Unit = { - allReduceResults.foreachPartition(() => _) - val byPartitionResults = allReduceResults.collect() - assert(byPartitionResults(0).length == vectorLength) - collectedAllReduceResults.put(byPartitionResults(0)) - } - } - sparkThread.start() - assert(tracker.waitFor(0L) == 0) - sparkThread.join() - - assert(collectedAllReduceResults.poll().sameElements(maxVec)) - } - test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") { /* Deliberately create new instances of SparkContext in each unit test to avoid reusing the @@ -193,68 +140,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest { assert(tracker.waitFor(0) != 0) } - test("test Scala RabitTracker's exception handling: it should not hang forever.") { - val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache() - - val tracker = new ScalaRabitTracker(numWorkers) - tracker.start(0) - val trackerEnvs = tracker.getWorkerEnvs - - val workerCount: Int = numWorkers - val dummyTasks = rdd.mapPartitions { iter => - Communicator.init(trackerEnvs) - val index = iter.next() - Thread.sleep(100 + index * 10) - if (index == workerCount) { - // kill the worker by throwing an exception - throw new RuntimeException("Worker exception.") - } - Communicator.shutdown() - Iterator(index) - }.cache() - - val sparkThread = new Thread() { - override def run(): Unit = { - // forces a Spark job. - dummyTasks.foreachPartition(() => _) - } - } - sparkThread.setUncaughtExceptionHandler(tracker) - sparkThread.start() - assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode) - } - - test("test Scala RabitTracker's workerConnectionTimeout") { - val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache() - - val tracker = new ScalaRabitTracker(numWorkers) - tracker.start(500) - val trackerEnvs = tracker.getWorkerEnvs - - val dummyTasks = rdd.mapPartitions { iter => - val index = iter.next() - // simulate that the first worker cannot connect to tracker due to network issues. - if (index != 1) { - Communicator.init(trackerEnvs) - Thread.sleep(1000) - Communicator.shutdown() - } - - Iterator(index) - }.cache() - - val sparkThread = new Thread() { - override def run(): Unit = { - // forces a Spark job. - dummyTasks.foreachPartition(() => _) - } - } - sparkThread.setUncaughtExceptionHandler(tracker) - sparkThread.start() - // should fail due to connection timeout - assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode) - } - test("should allow the dataframe containing communicator calls to be partially evaluated for" + " multiple times (ISSUE-4406)") { val paramMap = Map( diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala deleted file mode 100644 index fb388d083..000000000 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.rabit - -import java.net.{InetAddress, InetSocketAddress} - -import akka.actor.ActorSystem -import akka.pattern.ask -import ml.dmlc.xgboost4j.java.{IRabitTracker, TrackerProperties} -import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler - -import scala.concurrent.duration._ -import scala.concurrent.{Await, Future} -import scala.util.{Failure, Success, Try} - -/** - * Scala implementation of the Rabit tracker interface without Python dependency. - * The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker - * (and thus any distributed tasks) to hang indefinitely due to network issues or worker node - * failures. - * - * Note that this implementation is currently experimental, and should be used at your own risk. - * - * Example usage: - * {{{ - * import scala.concurrent.duration._ - * - * val tracker = new RabitTracker(32) - * // allow up to 10 minutes for all workers to connect to the tracker. - * tracker.start(10 minutes) - * - * /* ... - * launching workers in parallel - * ... - * */ - * - * // wait for worker execution up to 6 hours. - * // providing a finite timeout prevents a long-running task from hanging forever in - * // catastrophic events, like the loss of an executor during model training. - * tracker.waitFor(6 hours) - * }}} - * - * @param numWorkers Number of distributed workers from which the tracker expects connections. - * @param port The minimum port number that the tracker binds to. - * If port is omitted, or given as None, a random ephemeral port is chosen at runtime. - * @param maxPortTrials The maximum number of trials of socket binding, by sequentially - * increasing the port number. - */ -private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None, - maxPortTrials: Int = 1000) - extends IRabitTracker { - - import scala.collection.JavaConverters._ - - require(numWorkers >=1, "numWorkers must be greater than or equal to one (1).") - - val system = ActorSystem.create("RabitTracker") - val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler") - implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds) - private[this] val tcpBindingTimeout: Duration = 1 minute - - var workerEnvs: Map[String, String] = Map.empty - - override def uncaughtException(t: Thread, e: Throwable): Unit = { - handler ? RabitTrackerHandler.InterruptTracker(e) - } - - /** - * Start the Rabit tracker. - * - * @param timeout The timeout for awaiting connections from worker nodes. - * Note that when used in Spark applications, because all Spark transformations are - * lazily executed, the I/O time for loading RDDs/DataFrames from external sources - * (local dist, HDFS, S3 etc.) must be taken into account for the timeout value. - * If the timeout value is too small, the Rabit tracker will likely timeout before workers - * establishing connections to the tracker, due to the overhead of loading data. - * Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver - * running it) from hanging indefinitely due to worker connection issues (e.g. firewall.) - * @return Boolean flag indicating if the Rabit tracker starts successfully. - */ - private def start(timeout: Duration): Boolean = { - val hostAddress = Option(TrackerProperties.getInstance().getHostIp) - .map(InetAddress.getByName).getOrElse(InetAddress.getLocalHost) - - handler ? RabitTrackerHandler.StartTracker( - new InetSocketAddress(hostAddress, port.getOrElse(0)), maxPortTrials, timeout) - - // block by waiting for the actor to bind to a port - Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration) - .asInstanceOf[Future[Map[String, String]]]) match { - case Success(futurePortBound) => - // The success of the Future is contingent on binding to an InetSocketAddress. - val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess - if (isBound) { - workerEnvs = Await.result(futurePortBound, 0 nano) - } - isBound - case Failure(ex: Throwable) => - false - } - } - - /** - * Start the Rabit tracker. - * - * @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker - * connections. If a non-positive value is provided, the tracker - * waits for incoming worker connections indefinitely. - * @return Boolean flag indicating if the Rabit tracker starts successfully. - */ - def start(connectionTimeoutMillis: Long): Boolean = { - if (connectionTimeoutMillis <= 0) { - start(Duration.Inf) - } else { - start(Duration.fromNanos(connectionTimeoutMillis * 1e6)) - } - } - - def stop(): Unit = { - system.terminate() - } - - /** - * Get a Map of necessary environment variables to initiate Rabit workers. - * - * @return HashMap containing tracker information. - */ - def getWorkerEnvs: java.util.Map[String, String] = { - new java.util.HashMap((workerEnvs ++ Map( - "DMLC_NUM_WORKER" -> numWorkers.toString, - "DMLC_NUM_SERVER" -> "0" - )).asJava) - } - - /** - * Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds. - * This method blocks until timeout or task completion. - * - * @param atMost the maximum execution time for the workers. By default, - * the tracker waits for the workers indefinitely. - * @return 0 if the tasks complete successfully, and non-zero otherwise. - */ - private def waitFor(atMost: Duration): Int = { - // request the completion Future from the tracker actor - Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration) - .asInstanceOf[Future[Int]]) match { - case Success(futureCompleted) => - // wait for all workers to complete synchronously. - val statusCode = Try(Await.result(futureCompleted, atMost)) match { - case Success(n) if n == numWorkers => - IRabitTracker.TrackerStatus.SUCCESS.getStatusCode - case Success(n) if n < numWorkers => - IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode - case Failure(e) => - IRabitTracker.TrackerStatus.FAILURE.getStatusCode - } - system.terminate() - statusCode - case Failure(ex: Throwable) => - system.terminate() - IRabitTracker.TrackerStatus.FAILURE.getStatusCode - } - } - - /** - * Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds. - * This method blocks until timeout or task completion. - * - * @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a - * non-positive number is given, the tracker waits indefinitely. - * @return 0 if the tasks complete successfully, and non-zero otherwise - */ - def waitFor(atMostMillis: Long): Int = { - if (atMostMillis <= 0) { - waitFor(Duration.Inf) - } else { - waitFor(Duration.fromNanos(atMostMillis * 1e6)) - } - } -} - diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala deleted file mode 100644 index f9de71746..000000000 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala +++ /dev/null @@ -1,361 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.rabit.handler - -import java.net.InetSocketAddress -import java.util.UUID - -import scala.concurrent.duration._ -import scala.collection.mutable -import scala.concurrent.{Promise, TimeoutException} -import akka.io.{IO, Tcp} -import akka.actor._ -import ml.dmlc.xgboost4j.java.XGBoostError -import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap} - -import scala.util.{Failure, Random, Success, Try} - -/** The Akka actor for handling and coordinating Rabit worker connections. - * This is the main actor for handling socket connections, interacting with the synchronous - * tracker interface, and resolving tree/ring/parent dependencies between workers. - * - * @param numWorkers Number of workers to track. - */ -private[scala] class RabitTrackerHandler(numWorkers: Int) - extends Actor with ActorLogging { - - import context.system - import RabitWorkerHandler._ - import RabitTrackerHandler._ - - private[this] val promisedWorkerEnvs = Promise[Map[String, String]]() - private[this] val promisedShutdownWorkers = Promise[Int]() - private[this] val tcpManager = IO(Tcp) - - // resolves worker connection dependency. - val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver") - - // workers that have sent "shutdown" signal - private[this] val shutdownWorkers = mutable.Set.empty[Int] - private[this] val jobToRankMap = mutable.HashMap.empty[String, Int] - private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String] - private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*) - private[this] var maxPortTrials = 0 - private[this] var workerConnectionTimeout: Duration = Duration.Inf - private[this] var portTrials = 0 - private[this] val startedWorkers = mutable.Set.empty[Int] - - val linkMap = new LinkMap(numWorkers) - - def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = { - rank match { - case r if r >= 0 => Some(r) - case _ => - jobId match { - case "NULL" => None - case jid => jobToRankMap.get(jid) - } - } - } - - /** - * Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled - * by the RabitWorkerHandler. - * - * @param event Generic Tcp.Event - */ - private def handleTcpEvents(event: Tcp.Event): Unit = event match { - case Tcp.Bound(local) => - // expect all workers to connect within timeout - log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}") - log.info(s"Worker connection timeout is $workerConnectionTimeout.") - - context.setReceiveTimeout(workerConnectionTimeout) - promisedWorkerEnvs.success(Map( - "DMLC_TRACKER_URI" -> local.getAddress.getHostAddress, - "DMLC_TRACKER_PORT" -> local.getPort.toString, - // not required because the world size will be communicated to the - // worker node after the rank is assigned. - "rabit_world_size" -> numWorkers.toString - )) - - case Tcp.CommandFailed(cmd: Tcp.Bind) => - if (portTrials < maxPortTrials) { - portTrials += 1 - tcpManager ! Tcp.Bind(self, - new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1), - backlog = 256) - } - - case Tcp.Connected(remote, local) => - log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}") - // revoke timeout if all workers have connected. - val workerHandler = context.actorOf(RabitWorkerHandler.props( - remote.getAddress.getHostAddress, numWorkers, self, sender() - ), s"ConnectionHandler-${UUID.randomUUID().toString}") - val connection = sender() - connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true) - - actorRefToHost.put(workerHandler, remote.getAddress.getHostName) - } - - /** - * Handles external tracker control messages sent by RabitTracker (usually in ask patterns) - * to interact with the tracker interface. - * - * @param trackerMsg control messages sent by RabitTracker class. - */ - private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit = - trackerMsg match { - - case msg: StartTracker => - maxPortTrials = msg.maxPortTrials - workerConnectionTimeout = msg.connectionTimeout - - // if the port number is missing, try binding to a random ephemeral port. - if (msg.addr.getPort == 0) { - tcpManager ! Tcp.Bind(self, - new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768), - backlog = 256) - } else { - tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256) - } - sender() ! true - - case RequestBoundFuture => - sender() ! promisedWorkerEnvs.future - - case RequestCompletionFuture => - sender() ! promisedShutdownWorkers.future - - case InterruptTracker(e) => - log.error(e, "Uncaught exception thrown by worker.") - // make sure that waitFor() does not hang indefinitely. - promisedShutdownWorkers.failure(e) - context.stop(self) - } - - /** - * Handles messages sent by child actors representing connecting Rabit workers, by brokering - * messages to the dependency resolver, and processing worker commands. - * - * @param workerMsg Message sent by RabitWorkerHandler actors. - */ - private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match { - case req @ RequestAwaitConnWorkers(_, _) => - // since the requester may request to connect to other workers - // that have not fully set up, delegate this request to the - // dependency resolver which handles the dependencies properly. - resolver forward req - - // ---- Rabit worker commands: start/recover/shutdown/print ---- - case WorkerTrackerPrint(_, _, _, msg) => - log.info(msg.trim) - - case WorkerShutdown(rank, _, _) => - assert(rank >= 0, "Invalid rank.") - assert(!shutdownWorkers.contains(rank)) - shutdownWorkers.add(rank) - - log.info(s"Received shutdown signal from $rank") - - if (shutdownWorkers.size == numWorkers) { - promisedShutdownWorkers.success(shutdownWorkers.size) - } - - case WorkerRecover(prevRank, worldSize, jobId) => - assert(prevRank >= 0) - sender() ! linkMap.assignRank(prevRank) - - case WorkerStart(rank, worldSize, jobId) => - assert(worldSize == numWorkers || worldSize == -1, - s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)." - ) - - Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match { - case Success(wkRank) => - if (jobId != "NULL") { - jobToRankMap.put(jobId, wkRank) - } - - val assignedRank = linkMap.assignRank(wkRank) - sender() ! assignedRank - resolver ! assignedRank - - log.info("Received start signal from " + - s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]") - - case Failure(ex: IndexOutOfBoundsException) => - // More than worldSize workers have connected, likely due to executor loss. - // Since Rabit currently does not support crash recovery (because the Allreduce results - // are not cached by the tracker, and because existing workers cannot reestablish - // connections to newly spawned executor/worker), the most reasonble action here is to - // interrupt the tracker immediate with failure state. - log.error("Received invalid start signal from " + - s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started." - ) - promisedShutdownWorkers.failure(new XGBoostError("Invalid start signal" + - " received from worker, likely due to executor loss.")) - - case Failure(ex) => - log.error(ex, "Unexpected error") - promisedShutdownWorkers.failure(ex) - } - - - // ---- Dependency resolving related messages ---- - case msg @ WorkerStarted(host, rank, awaitingAcceptance) => - log.info(s"Worker $host (rank: $rank) has started.") - resolver forward msg - - startedWorkers.add(rank) - if (startedWorkers.size == numWorkers) { - log.info("All workers have started.") - } - - case req @ DropFromWaitingList(_) => - // all peer workers in dependency link map have connected; - // forward message to resolver to update dependencies. - resolver forward req - - case _ => - } - - def receive: Actor.Receive = { - case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent) - case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg) - case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg) - - case akka.actor.ReceiveTimeout => - if (startedWorkers.size < numWorkers) { - promisedShutdownWorkers.failure( - new TimeoutException("Timed out waiting for workers to connect: " + - s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.") - ) - context.stop(self) - } - - context.setReceiveTimeout(Duration.Undefined) - } -} - -/** - * Resolve the dependency between nodes as they connect to the tracker. - * The dependency is enforced that a worker of rank K depends on its neighbors (from the treeMap - * and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order of - * connections by workers, this dependency constraint assumes that a worker node connects first - * is likely to finish setup first. - */ -private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging { - import RabitWorkerHandler._ - - context.watch(handler) - - case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections]) - - // worker nodes that have connected, but have not send WorkerStarted message. - private val dependencyMap = mutable.Map.empty[Int, Set[Int]] - private val startedWorkers = mutable.Set.empty[Int] - // worker nodes that have started, and await for connections. - private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef] - private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment] - - def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = { - val connSet = awaitConnWorkers.toMap - .filterKeys(k => linkSet.contains(k)) - AwaitingConnections(connSet, linkSet.size - connSet.size) - } - - def receive: Actor.Receive = { - // a copy of the AssignedRank message that is also sent to the worker - case AssignedRank(rank, tree_neighbors, ring, parent) => - // the workers that the worker of given `rank` depends on: - // worker of rank K only depends on workers with rank smaller than K. - val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2)) - .filter{ r => r != -1 && r < rank} - - log.debug(s"Rank $rank connected, dependencies: $dependentWorkers") - dependencyMap.put(rank, dependentWorkers) - - case RequestAwaitConnWorkers(rank, toConnectSet) => - val promise = Promise[AwaitingConnections]() - - assert(dependencyMap.contains(rank)) - - val updatedDependency = dependencyMap(rank) diff startedWorkers - if (updatedDependency.isEmpty) { - // all dependencies are satisfied - log.debug(s"Rank $rank has all dependencies satisfied.") - promise.success(awaitingWorkers(toConnectSet)) - } else { - log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.") - // promise is pending fulfillment due to unresolved dependency - pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise)) - } - - sender() ! promise.future - - case WorkerStarted(_, started, awaitingAcceptance) => - startedWorkers.add(started) - if (awaitingAcceptance > 0) { - awaitConnWorkers.put(started, sender()) - } - - // remove the started rank from all dependencies. - dependencyMap.remove(started) - dependencyMap.foreach { case (r, dset) => - val updatedDependency = dset diff startedWorkers - // fulfill the future if all dependencies are met (started.) - if (updatedDependency.isEmpty) { - log.debug(s"Rank $r has all dependencies satisfied.") - pendingFulfillment.remove(r).map{ - case Fulfillment(toConnectSet, promise) => - promise.success(awaitingWorkers(toConnectSet)) - } - } - - dependencyMap.update(r, updatedDependency) - } - - case DropFromWaitingList(rank) => - assert(awaitConnWorkers.remove(rank).isDefined) - - case Terminated(ref) => - if (ref.equals(handler)) { - context.stop(self) - } - } -} - -private[scala] object RabitTrackerHandler { - // Messages sent by RabitTracker to this RabitTrackerHandler actor - trait TrackerControlMessage - case object RequestCompletionFuture extends TrackerControlMessage - case object RequestBoundFuture extends TrackerControlMessage - // Start the Rabit tracker at given socket address awaiting worker connections. - // All workers must connect to the tracker before connectionTimeout, otherwise the tracker will - // shut down due to timeout. - case class StartTracker(addr: InetSocketAddress, - maxPortTrials: Int, - connectionTimeout: Duration) extends TrackerControlMessage - // To interrupt the tracker handler due to uncaught exception thrown by the thread acting as - // driver for the distributed training. - case class InterruptTracker(e: Throwable) extends TrackerControlMessage - - def props(numWorkers: Int): Props = - Props(new RabitTrackerHandler(numWorkers)) -} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala deleted file mode 100644 index 234c4d25a..000000000 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala +++ /dev/null @@ -1,467 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.rabit.handler - -import java.nio.{ByteBuffer, ByteOrder} - -import akka.io.Tcp -import akka.actor._ -import akka.util.ByteString -import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers} - -import scala.concurrent.{Await, Future} -import scala.concurrent.duration._ -import scala.util.Try - -/** - * Actor to handle socket communication from worker node. - * To handle fragmentation in received data, this class acts like a FSM - * (finite-state machine) to keep track of the internal states. - * - * @param host IP address of the remote worker - * @param worldSize number of total workers - * @param tracker the RabitTrackerHandler actor reference - */ -private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef, - connection: ActorRef) - extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct] - with ActorLogging with Stash { - - import RabitWorkerHandler._ - import RabitTrackerHelpers._ - - private[this] var rank: Int = 0 - private[this] var port: Int = 0 - - // indicate if the connection is transient (like "print" or "shutdown") - private[this] var transient: Boolean = false - private[this] var peerClosed: Boolean = false - - // number of workers pending acceptance of current worker - private[this] var awaitingAcceptance: Int = 0 - private[this] var neighboringWorkers = Set.empty[Int] - - // TODO: use a single memory allocation to host all buffers, - // including the transient ones for writing. - private[this] val readBuffer = ByteBuffer.allocate(4096) - .order(ByteOrder.nativeOrder()) - // in case the received message is longer than needed, - // stash the spilled over part in this buffer, and send - // to self when transition occurs. - private[this] val spillOverBuffer = ByteBuffer.allocate(4096) - .order(ByteOrder.nativeOrder()) - // when setup is complete, need to notify peer handlers - // to reduce the awaiting-connection counter. - private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None - - private def resetBuffers(): Unit = { - readBuffer.clear() - if (spillOverBuffer.position() > 0) { - spillOverBuffer.flip() - self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer)) - spillOverBuffer.clear() - } - } - - private def stashSpillOver(buf: ByteBuffer): Unit = { - if (buf.remaining() > 0) spillOverBuffer.put(buf) - } - - def getNeighboringWorkers: Set[Int] = neighboringWorkers - - def decodeCommand(buffer: ByteBuffer): TrackerCommand = { - val readBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) - readBuffer.flip() - - val rank = readBuffer.getInt() - val worldSize = readBuffer.getInt() - val jobId = readBuffer.getString - - val command = readBuffer.getString - val trackerCommand = command match { - case "start" => WorkerStart(rank, worldSize, jobId) - case "shutdown" => - transient = true - WorkerShutdown(rank, worldSize, jobId) - case "recover" => - require(rank >= 0, "Invalid rank for recovering worker.") - WorkerRecover(rank, worldSize, jobId) - case "print" => - transient = true - WorkerTrackerPrint(rank, worldSize, jobId, readBuffer.getString) - } - - stashSpillOver(readBuffer) - trackerCommand - } - - startWith(AwaitingHandshake, DataStruct()) - - when(AwaitingHandshake) { - case Event(Tcp.Received(magic), _) => - assert(magic.length == 4) - val purportedMagic = magic.asNativeOrderByteBuffer.getInt - assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host") - - // echo back the magic number - connection ! Tcp.Write(magic) - goto(AwaitingCommand) using StructTrackerCommand - } - - when(AwaitingCommand) { - case Event(Tcp.Received(bytes), validator) => - bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) } - if (validator.verify(readBuffer)) { - Try(decodeCommand(readBuffer)) match { - case scala.util.Success(decodedCommand) => - tracker ! decodedCommand - case scala.util.Failure(th: java.nio.BufferUnderflowException) => - // BufferUnderflowException would occur if the message to print has not arrived yet. - // Do nothing, wait for next Tcp.Received event - case scala.util.Failure(th: Throwable) => throw th - } - } - - stay - // when rank for a worker is assigned, send encoded rank information - // back to worker over Tcp socket. - case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) => - log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent") - - rank = assignedRank - // ranks from the ring - val ringRanks = List( - // ringPrev - if (ring._1 != -1 && ring._1 != rank) ring._1 else -1, - // ringNext - if (ring._2 != -1 && ring._2 != rank) ring._2 else -1 - ) - - // update the set of all linked workers to current worker. - neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet - - connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize))) - // to prevent reading before state transition - connection ! Tcp.SuspendReading - goto(BuildingLinkMap) using StructNodes - } - - when(BuildingLinkMap) { - case Event(Tcp.Received(bytes), validator) => - bytes.asByteBuffers.foreach { buf => - readBuffer.put(buf) - } - - if (validator.verify(readBuffer)) { - readBuffer.flip() - // for a freshly started worker, numConnected should be 0. - val numConnected = readBuffer.getInt() - val toConnectSet = neighboringWorkers.diff( - (0 until numConnected).map { index => readBuffer.getInt() }.toSet) - - // check which workers are currently awaiting connections - tracker ! RequestAwaitConnWorkers(rank, toConnectSet) - } - stay - - // got a Future from the tracker (resolver) about workers that are - // currently awaiting connections (particularly from this node.) - case Event(future: Future[_], _) => - // blocks execution until all dependencies for current worker is resolved. - Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match { - // numNotReachable is the number of workers that currently - // cannot be connected to (pending connection or setup). Instead, this worker will AWAIT - // connections from those currently non-reachable nodes in the future. - case AwaitingConnections(waitConnNodes, numNotReachable) => - log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable") - val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder()) - buf.putInt(waitConnNodes.size).putInt(numNotReachable) - buf.flip() - - // cache this message until the final state (SetupComplete) - pendingAcknowledgement = Some(AcknowledgeAcceptance( - waitConnNodes, numNotReachable)) - - connection ! Tcp.Write(ByteString.fromByteBuffer(buf)) - if (waitConnNodes.isEmpty) { - connection ! Tcp.SuspendReading - goto(AwaitingErrorCount) - } - else { - waitConnNodes.foreach { case (peerRank, peerRef) => - peerRef ! RequestWorkerHostPort - } - - // a countdown for DivulgedHostPort messages. - stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1) - } - } - - case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) => - val hostBytes = peerHost.getBytes() - val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length) - .order(ByteOrder.nativeOrder()) - buffer.putInt(peerHost.length).put(hostBytes) - .putInt(peerPort).putInt(peerRank) - - buffer.flip() - connection ! Tcp.Write(ByteString.fromByteBuffer(buffer)) - - if (data.counter == 0) { - // to prevent reading before state transition - connection ! Tcp.SuspendReading - goto(AwaitingErrorCount) - } - else { - stay using data.decrement() - } - } - - when(AwaitingErrorCount) { - case Event(Tcp.Received(numErrors), _) => - val buf = numErrors.asNativeOrderByteBuffer - - buf.getInt match { - case 0 => - stashSpillOver(buf) - goto(AwaitingPortNumber) - case _ => - stashSpillOver(buf) - goto(BuildingLinkMap) using StructNodes - } - } - - when(AwaitingPortNumber) { - case Event(Tcp.Received(assignedPort), _) => - assert(assignedPort.length == 4) - port = assignedPort.asNativeOrderByteBuffer.getInt - log.debug(s"Rank $rank listening @ $host:$port") - // wait until the worker closes connection. - if (peerClosed) goto(SetupComplete) else stay - - case Event(Tcp.PeerClosed, _) => - peerClosed = true - if (port == 0) stay else goto(SetupComplete) - } - - when(SetupComplete) { - case Event(ReduceWaitCount(count: Int), _) => - awaitingAcceptance -= count - // check peerClosed to avoid prematurely stopping this actor (which sends RST to worker) - if (awaitingAcceptance == 0 && peerClosed) { - tracker ! DropFromWaitingList(rank) - // no longer needed. - context.stop(self) - } - stay - - case Event(AcknowledgeAcceptance(peers, numBad), _) => - awaitingAcceptance = numBad - tracker ! WorkerStarted(host, rank, awaitingAcceptance) - peers.values.foreach { peer => - peer ! ReduceWaitCount(1) - } - - if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill - - stay - - // can only divulge the complete host and port information - // when this worker is declared fully connected (otherwise - // port information is still missing.) - case Event(RequestWorkerHostPort, _) => - sender() ! DivulgedWorkerHostPort(rank, host, port) - stay - } - - onTransition { - // reset buffer when state transitions as data becomes stale - case _ -> SetupComplete => - connection ! Tcp.ResumeReading - resetBuffers() - if (pendingAcknowledgement.isDefined) { - self ! pendingAcknowledgement.get - } - case _ => - connection ! Tcp.ResumeReading - resetBuffers() - } - - // default message handler - whenUnhandled { - case Event(Tcp.PeerClosed, _) => - peerClosed = true - if (transient) context.stop(self) - stay - } -} - -private[scala] object RabitWorkerHandler { - val MAGIC_NUMBER = 0xff99 - - // Finite states of this actor, which acts like a FSM. - // The following states are defined in order as the FSM progresses. - sealed trait State - - // [1] Initial state, awaiting worker to send magic number per protocol. - case object AwaitingHandshake extends State - // [2] Awaiting worker to send command (start/print/recover/shutdown etc.) - case object AwaitingCommand extends State - // [3] Brokers connections between workers per ring/tree/parent link map. - case object BuildingLinkMap extends State - // [4] A transient state in which the worker reports the number of errors in establishing - // connections to other peer workers. If no errors, transition to next state. - case object AwaitingErrorCount extends State - // [5] Awaiting the worker to report its port number for accepting connections from peer workers. - // This port number information is later forwarded to linked workers. - case object AwaitingPortNumber extends State - // [6] Final state after completing the setup with the connecting worker. At this stage, the - // worker will have closed the Tcp connection. The actor remains alive to handle messages from - // peer actors representing workers with pending setups. - case object SetupComplete extends State - - sealed trait DataField - case object IntField extends DataField - // an integer preceding the actual string - case object StringField extends DataField - case object IntSeqField extends DataField - - object DataStruct { - def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0) - } - - // Internal data pertaining to individual state, used to verify the validity of packets sent by - // workers. - case class DataStruct(fields: Seq[DataField], counter: Int) { - /** - * Validate whether the provided buffer is complete (i.e., contains - * all data fields specified for this DataStruct.) - * - * @param buf a byte buffer containing received data. - */ - def verify(buf: ByteBuffer): Boolean = { - if (fields.isEmpty) return true - - val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder()) - dupBuf.flip() - - Try(fields.foldLeft(true) { - case (complete, field) => - val remBytes = dupBuf.remaining() - complete && (remBytes > 0) && (remBytes >= (field match { - case IntField => - dupBuf.position(dupBuf.position() + 4) - 4 - case StringField => - val strLen = dupBuf.getInt - dupBuf.position(dupBuf.position() + strLen) - 4 + strLen - case IntSeqField => - val seqLen = dupBuf.getInt - dupBuf.position(dupBuf.position() + seqLen * 4) - 4 + seqLen * 4 - })) - }).getOrElse(false) - } - - def increment(): DataStruct = DataStruct(fields, counter + 1) - def decrement(): DataStruct = DataStruct(fields, counter - 1) - } - - val StructNodes = DataStruct(List(IntSeqField), 0) - val StructTrackerCommand = DataStruct(List( - IntField, IntField, StringField, StringField - ), 0) - - // ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ---- - - // RabitWorkerHandler --> RabitTrackerHandler - sealed trait RabitWorkerRequest - // RabitWorkerHandler <-- RabitTrackerHandler - sealed trait RabitWorkerResponse - - // Representations of decoded worker commands. - abstract class TrackerCommand(val command: String) extends RabitWorkerRequest { - def rank: Int - def worldSize: Int - def jobId: String - - def encode: ByteString = { - val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length) - .order(ByteOrder.nativeOrder()) - - buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes()) - .putInt(command.length).put(command.getBytes()).flip() - - ByteString.fromByteBuffer(buf) - } - } - - case class WorkerStart(rank: Int, worldSize: Int, jobId: String) - extends TrackerCommand("start") - case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String) - extends TrackerCommand("shutdown") - case class WorkerRecover(rank: Int, worldSize: Int, jobId: String) - extends TrackerCommand("recover") - case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String) - extends TrackerCommand("print") { - - override def encode: ByteString = { - val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length) - .order(ByteOrder.nativeOrder()) - - buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes()) - .putInt(command.length).put(command.getBytes()) - .putInt(msg.length).put(msg.getBytes()).flip() - - ByteString.fromByteBuffer(buf) - } - } - - // Request to remove the worker of given rank from the list of workers awaiting peer connections. - case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest - // Notify the tracker that the worker of given rank has finished setup and started. - case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int) - extends RabitWorkerRequest - // Request the set of workers to connect to, according to the LinkMap structure. - case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int]) - extends RabitWorkerRequest - - // Request, from the tracker, the set of nodes to connect. - case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int) - extends RabitWorkerResponse - - // ---- Messages between ConnectionHandler actors ---- - sealed trait IntraWorkerMessage - - // Notify neighboring workers to decrease the counter of awaiting workers by `count`. - case class ReduceWaitCount(count: Int) extends IntraWorkerMessage - // Request host and port information from peer ConnectionHandler actors (acting on behave of - // connecting workers.) This message will be brokered by RabitTrackerHandler. - case object RequestWorkerHostPort extends IntraWorkerMessage - // Response to the above request - case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage - // A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete". - case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int) - extends IntraWorkerMessage - - // ---- End of message definitions ---- - - def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = { - Props(new RabitWorkerHandler(host, worldSize, tracker, connection)) - } -} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala deleted file mode 100644 index edec4931b..000000000 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala +++ /dev/null @@ -1,136 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.rabit.util - -import java.nio.{ByteBuffer, ByteOrder} - -/** - * The assigned rank to a connecting Rabit worker, along with the information of the ranks of - * its linked peer workers, which are critical to perform Allreduce. - * When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker - * client, RabitTrackerHandler utilizes LinkMap to figure out linkage relationships, and respond - * with this class as a message, which is later encoded to byte string, and sent over socket - * connection to the worker client. - * - * @param rank assigned rank (ranked by worker connection order: first worker connecting to the - * tracker is assigned rank 0, second with rank 1, etc.) - * @param neighbors ranks of neighboring workers in a tree map. - * @param ring ranks of neighboring workers in a ring map. - * @param parent rank of the parent worker. - */ -private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int], - ring: (Int, Int), parent: Int) { - /** - * Encode the AssignedRank message into byte sequence for socket communication with Rabit worker - * client. - * @param worldSize the number of total distributed workers. Must match `numWorkers` used in - * LinkMap. - * @return a ByteBuffer containing encoded data. - */ - def toByteBuffer(worldSize: Int): ByteBuffer = { - val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder()) - buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length) - // neighbors in tree structure - neighbors.foreach { n => buffer.putInt(n) } - buffer.putInt(if (ring._1 != -1 && ring._1 != rank) ring._1 else -1) - buffer.putInt(if (ring._2 != -1 && ring._2 != rank) ring._2 else -1) - - buffer.flip() - buffer - } -} - -private[rabit] class LinkMap(numWorkers: Int) { - private def getNeighbors(rank: Int): Seq[Int] = { - val rank1 = rank + 1 - Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r => - r >= 0 && r < numWorkers - } - } - - /** - * Construct a ring structure that tends to share nodes with the tree. - * - * @param treeMap - * @param parentMap - * @param rank - * @return Seq[Int] instance starting from rank. - */ - private def constructShareRing(treeMap: Map[Int, Seq[Int]], - parentMap: Map[Int, Int], - rank: Int = 0): Seq[Int] = { - treeMap(rank).toSet - parentMap(rank) match { - case emptySet if emptySet.isEmpty => - List(rank) - case connectionSet => - connectionSet.zipWithIndex.foldLeft(List(rank)) { - case (ringSeq, (v, cnt)) => - val vConnSeq = constructShareRing(treeMap, parentMap, v) - vConnSeq match { - case vconn if vconn.size == cnt + 1 => - ringSeq ++ vconn.reverse - case vconn => - ringSeq ++ vconn - } - } - } - } - /** - * Construct a ring connection used to recover local data. - * - * @param treeMap - * @param parentMap - */ - private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = { - assert(parentMap(0) == -1) - - val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector - assert(sharedRing.length == treeMap.size) - - (0 until numWorkers).map { r => - val rPrev = (r + numWorkers - 1) % numWorkers - val rNext = (r + 1) % numWorkers - sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext)) - }.toMap - } - - private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap - private[this] val parentMap_ = (0 until numWorkers).map{ r => r -> ((r + 1) / 2 - 1) }.toMap - private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_) - val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) { - case ((rmap, k), i) => - val kNext = ringMap_(k)._2 - (rmap ++ Map(kNext -> (i + 1)), kNext) - }._1 - - val ringMap = ringMap_.map { - case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1)) - } - val treeMap = treeMap_.map { - case (k, vSeq) => rMap_(k) -> vSeq.map{ v => rMap_(v) } - } - val parentMap = parentMap_.map { - case (k, v) if k == 0 => - rMap_(k) -> -1 - case (k, v) => - rMap_(k) -> rMap_(v) - } - - def assignRank(rank: Int): AssignedRank = { - AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank)) - } -} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala deleted file mode 100644 index 3d7be618d..000000000 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.rabit.util - -import java.nio.{ByteOrder, ByteBuffer} -import akka.util.ByteString - -private[rabit] object RabitTrackerHelpers { - implicit class ByteStringHelplers(bs: ByteString) { - // Java by default uses big endian. Enforce native endian so that - // the byte order is consistent with the workers. - def asNativeOrderByteBuffer: ByteBuffer = { - bs.asByteBuffer.order(ByteOrder.nativeOrder()) - } - } - - implicit class ByteBufferHelpers(buf: ByteBuffer) { - def getString: String = { - val len = buf.getInt() - val stringBuffer = ByteBuffer.allocate(len).order(ByteOrder.nativeOrder()) - buf.get(stringBuffer.array(), 0, len) - new String(stringBuffer.array(), "utf-8") - } - } -} diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala deleted file mode 100644 index cd9016812..000000000 --- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala +++ /dev/null @@ -1,255 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.rabit - -import java.nio.{ByteBuffer, ByteOrder} - -import akka.actor.{ActorRef, ActorSystem} -import akka.io.Tcp -import akka.testkit.{ImplicitSender, TestFSMRef, TestKit, TestProbe} -import akka.util.ByteString -import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler -import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler._ -import ml.dmlc.xgboost4j.scala.rabit.util.LinkMap -import org.junit.runner.RunWith -import org.scalatest.junit.JUnitRunner -import org.scalatest.{FlatSpecLike, Matchers} - -import scala.concurrent.Promise - -object RabitTrackerConnectionHandlerTest { - def intSeqToByteString(seq: Seq[Int]): ByteString = { - val buf = ByteBuffer.allocate(seq.length * 4).order(ByteOrder.nativeOrder()) - seq.foreach { i => buf.putInt(i) } - buf.flip() - ByteString.fromByteBuffer(buf) - } -} - -@RunWith(classOf[JUnitRunner]) -class RabitTrackerConnectionHandlerTest - extends TestKit(ActorSystem("RabitTrackerConnectionHandlerTest")) - with FlatSpecLike with Matchers with ImplicitSender { - - import RabitTrackerConnectionHandlerTest._ - - val magic = intSeqToByteString(List(0xff99)) - - "RabitTrackerConnectionHandler" should "handle Rabit client 'start' command properly" in { - val trackerProbe = TestProbe() - val connProbe = TestProbe() - - val worldSize = 4 - - val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize, - trackerProbe.ref, connProbe.ref)) - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake - - // send mock magic number - fsm ! Tcp.Received(magic) - connProbe.expectMsg(Tcp.Write(magic)) - - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand - fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand - // ResumeReading should be seen once state transitions - connProbe.expectMsg(Tcp.ResumeReading) - - // send mock tracker command in fragments: the handler should be able to handle it. - val bufRank = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder()) - bufRank.putInt(0).putInt(worldSize).flip() - - val bufJobId = ByteBuffer.allocate(5).order(ByteOrder.nativeOrder()) - bufJobId.putInt(1).put(Array[Byte]('0')).flip() - - val bufCmd = ByteBuffer.allocate(9).order(ByteOrder.nativeOrder()) - bufCmd.putInt(5).put("start".getBytes()).flip() - - fsm ! Tcp.Received(ByteString.fromByteBuffer(bufRank)) - fsm ! Tcp.Received(ByteString.fromByteBuffer(bufJobId)) - - // the state should not change for incomplete command data. - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand - - // send the last fragment, and expect message at tracker actor. - fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd)) - trackerProbe.expectMsg(WorkerStart(0, worldSize, "0")) - - val linkMap = new LinkMap(worldSize) - val assignedRank = linkMap.assignRank(0) - trackerProbe.reply(assignedRank) - - connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer( - assignedRank.toByteBuffer(worldSize) - ))) - - // reading should be suspended upon transitioning to BuildingLinkMap - connProbe.expectMsg(Tcp.SuspendReading) - // state should transition with according state data changes. - fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap - fsm.stateData shouldEqual RabitWorkerHandler.StructNodes - connProbe.expectMsg(Tcp.ResumeReading) - - // since the connection handler in test has rank 0, it will not have any nodes to connect to. - fsm ! Tcp.Received(intSeqToByteString(List(0))) - trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers)) - - // return mock response to the connection handler - val awaitConnPromise = Promise[AwaitingConnections]() - awaitConnPromise.success(AwaitingConnections(Map.empty[Int, ActorRef], - fsm.underlyingActor.getNeighboringWorkers.size - )) - fsm ! awaitConnPromise.future - connProbe.expectMsg(Tcp.Write( - intSeqToByteString(List(0, fsm.underlyingActor.getNeighboringWorkers.size)) - )) - connProbe.expectMsg(Tcp.SuspendReading) - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingErrorCount - connProbe.expectMsg(Tcp.ResumeReading) - - // send mock error count (0) - fsm ! Tcp.Received(intSeqToByteString(List(0))) - - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber - connProbe.expectMsg(Tcp.ResumeReading) - - // simulate Tcp.PeerClosed event first, then Tcp.Received to test handling of async events. - fsm ! Tcp.PeerClosed - // state should not transition - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber - fsm ! Tcp.Received(intSeqToByteString(List(32768))) - - fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete - connProbe.expectMsg(Tcp.ResumeReading) - - trackerProbe.expectMsg(RabitWorkerHandler.WorkerStarted("localhost", 0, 2)) - - val handlerStopProbe = TestProbe() - handlerStopProbe watch fsm - - // simulate connections from other workers by mocking ReduceWaitCount commands - fsm ! RabitWorkerHandler.ReduceWaitCount(1) - fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete - fsm ! RabitWorkerHandler.ReduceWaitCount(1) - trackerProbe.expectMsg(RabitWorkerHandler.DropFromWaitingList(0)) - handlerStopProbe.expectTerminated(fsm) - - // all done. - } - - it should "forward print command to tracker" in { - val trackerProbe = TestProbe() - val connProbe = TestProbe() - - val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4, - trackerProbe.ref, connProbe.ref)) - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake - - fsm ! Tcp.Received(magic) - connProbe.expectMsg(Tcp.Write(magic)) - - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand - fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand - // ResumeReading should be seen once state transitions - connProbe.expectMsg(Tcp.ResumeReading) - - val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!") - fsm ! Tcp.Received(printCmd.encode) - - trackerProbe.expectMsg(printCmd) - } - - it should "handle fragmented print command without throwing exception" in { - val trackerProbe = TestProbe() - val connProbe = TestProbe() - - val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4, - trackerProbe.ref, connProbe.ref)) - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake - - fsm ! Tcp.Received(magic) - connProbe.expectMsg(Tcp.Write(magic)) - - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand - fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand - // ResumeReading should be seen once state transitions - connProbe.expectMsg(Tcp.ResumeReading) - - val printCmd = WorkerTrackerPrint(0, 4, "0", "fragmented!") - // 4 (rank: Int) + 4 (worldSize: Int) + (4+1) (jobId: String) + (4+5) (command: String) = 22 - val (partialMessage, remainder) = printCmd.encode.splitAt(22) - - // make sure that the partialMessage in itself is a valid command - val partialMsgBuf = ByteBuffer.allocate(22).order(ByteOrder.nativeOrder()) - partialMsgBuf.put(partialMessage.asByteBuffer) - RabitWorkerHandler.StructTrackerCommand.verify(partialMsgBuf) shouldBe true - - fsm ! Tcp.Received(partialMessage) - fsm ! Tcp.Received(remainder) - - trackerProbe.expectMsg(printCmd) - } - - it should "handle spill-over Tcp data correctly between state transition" in { - val trackerProbe = TestProbe() - val connProbe = TestProbe() - - val worldSize = 4 - - val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize, - trackerProbe.ref, connProbe.ref)) - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake - - // send mock magic number - fsm ! Tcp.Received(magic) - connProbe.expectMsg(Tcp.Write(magic)) - - fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand - fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand - // ResumeReading should be seen once state transitions - connProbe.expectMsg(Tcp.ResumeReading) - - // send mock tracker command in fragments: the handler should be able to handle it. - val bufCmd = ByteBuffer.allocate(26).order(ByteOrder.nativeOrder()) - bufCmd.putInt(0).putInt(worldSize).putInt(1).put(Array[Byte]('0')) - .putInt(5).put("start".getBytes()) - // spilled-over data - .putInt(0).flip() - - // send data with 4 extra bytes corresponding to the next state. - fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd)) - - trackerProbe.expectMsg(WorkerStart(0, worldSize, "0")) - - val linkMap = new LinkMap(worldSize) - val assignedRank = linkMap.assignRank(0) - trackerProbe.reply(assignedRank) - - connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer( - assignedRank.toByteBuffer(worldSize) - ))) - - // reading should be suspended upon transitioning to BuildingLinkMap - connProbe.expectMsg(Tcp.SuspendReading) - // state should transition with according state data changes. - fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap - fsm.stateData shouldEqual RabitWorkerHandler.StructNodes - connProbe.expectMsg(Tcp.ResumeReading) - - // the handler should be able to handle spill-over data, and stash it until state transition. - trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers)) - } -}