[breaking] [jvm-packages] Remove scala-implemented tracker. (#9045)

Jiaming Yuan 2023-04-20 16:29:35 +08:00 committed by GitHub
parent 42d100de18
commit 564df59204
8 changed files with 9 additions and 1585 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -23,7 +23,6 @@ import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -44,21 +43,16 @@ import org.apache.spark.sql.SparkSession
* Use a finite, non-zero timeout value to prevent the tracker from
* hanging indefinitely (in milliseconds).
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*
* @param hostIp The Rabit Tracker host IP address. This is only needed if the host IP
* cannot be automatically guessed.
* @param pythonExec The path of the Python executable used to launch the Rabit Tracker.
*/
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String,
case class TrackerConf(workerConnectionTimeout: Long,
hostIp: String = "", pythonExec: String = "")
object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L, "python")
def apply(): TrackerConf = TrackerConf(0L)
}
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
@@ -349,11 +343,9 @@ object XGBoost extends Serializable {
/** visible for testing */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker: IRabitTracker = trackerConf.trackerImpl match {
case "scala" => new RabitTracker(nWorkers)
case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec)
case _ => new PyRabitTracker(nWorkers)
}
val tracker: IRabitTracker = new PyRabitTracker(
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
)
tracker
}
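
Migration note: with trackerImpl gone, callers that passed the implementation name positionally must drop that argument; the Python-based tracker is now the only implementation. A minimal sketch (the host IP value is illustrative):

// before this commit: TrackerConf(0L, "python", hostIp = "10.0.0.5")
val conf = TrackerConf(workerConnectionTimeout = 0L, hostIp = "10.0.0.5")
val classifier = new XGBoostClassifier(Map(
  "num_workers" -> 4,
  "tracker_conf" -> conf
))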

View File

@@ -22,7 +22,6 @@ import scala.util.Random
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.scalatest.FunSuite
@@ -40,7 +39,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp))
"tracker_conf" -> TrackerConf(0L, hostIp))
val xgbExecParams = getXGBoostExecutionParams(paramMap)
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
tracker match {
@@ -53,7 +52,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap1 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
tracker1 match {
@@ -66,7 +65,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap2 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
tracker2 match {
@@ -78,58 +77,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
}
}
test("training with Scala-implemented Rabit tracker") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("test Communicator allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
val rawData = rdd.mapPartitions { iter =>
Iterator(iter.toArray)
}.collect()
val maxVec = (0 until vectorLength).toArray.map { j =>
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
}
val allReduceResults = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
val arr = iter.toArray
val results = Communicator.allReduce(arr, Communicator.OpType.MAX)
Communicator.shutdown()
Iterator(results)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
allReduceResults.foreachPartition(_ => ())
val byPartitionResults = allReduceResults.collect()
assert(byPartitionResults(0).length == vectorLength)
collectedAllReduceResults.put(byPartitionResults(0))
}
}
sparkThread.start()
assert(tracker.waitFor(0L) == 0)
sparkThread.join()
assert(collectedAllReduceResults.poll().sameElements(maxVec))
}
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@@ -193,68 +140,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
assert(tracker.waitFor(0) != 0)
}
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val workerCount: Int = numWorkers
val dummyTasks = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Communicator.shutdown()
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(_ => ())
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}
test("test Scala RabitTracker's workerConnectionTimeout") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(500)
val trackerEnvs = tracker.getWorkerEnvs
val dummyTasks = rdd.mapPartitions { iter =>
val index = iter.next()
// simulate that the first worker cannot connect to tracker due to network issues.
if (index != 1) {
Communicator.init(trackerEnvs)
Thread.sleep(1000)
Communicator.shutdown()
}
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(_ => ())
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
// should fail due to connection timeout
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}
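
The deleted tests above exercised the Scala tracker; the worker-side collective pattern they demonstrate is unchanged under the remaining Python-based tracker. A condensed sketch of that lifecycle, assuming the shared IRabitTracker interface and an rdd/numWorkers in scope as elsewhere in this suite:

val tracker = new PyRabitTracker(numWorkers)
tracker.start(0)                                 // 0: wait indefinitely for worker connections
val trackerEnvs = tracker.getWorkerEnvs
val results = rdd.mapPartitions { iter =>
  Communicator.init(trackerEnvs)                 // join the collective group
  val out = Communicator.allReduce(iter.toArray, Communicator.OpType.MAX)
  Communicator.shutdown()
  Iterator(out)
}.collect()
assert(tracker.waitFor(0L) == 0)                 // block until all workers finish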
test("should allow the dataframe containing communicator calls to be partially evaluated for" +
" multiple times (ISSUE-4406)") {
val paramMap = Map(

View File

@@ -1,195 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit
import java.net.{InetAddress, InetSocketAddress}
import akka.actor.ActorSystem
import akka.pattern.ask
import ml.dmlc.xgboost4j.java.{IRabitTracker, TrackerProperties}
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler
import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.util.{Failure, Success, Try}
/**
* Scala implementation of the Rabit tracker interface without Python dependency.
* The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker
* (and thus any distributed tasks) from hanging indefinitely due to network issues or worker node
* failures.
*
* Note that this implementation is currently experimental, and should be used at your own risk.
*
* Example usage:
* {{{
* val tracker = new RabitTracker(32)
* // allow up to 10 minutes (in milliseconds) for all workers to connect to the tracker.
* tracker.start(10 * 60 * 1000L)
*
* /* ...
* launching workers in parallel
* ...
* */
*
* // wait for worker execution for up to 6 hours.
* // providing a finite timeout prevents a long-running task from hanging forever in
* // catastrophic events, like the loss of an executor during model training.
* tracker.waitFor(6 * 60 * 60 * 1000L)
* }}}
*
* @param numWorkers Number of distributed workers from which the tracker expects connections.
* @param port The minimum port number that the tracker binds to.
* If port is omitted, or given as None, a random ephemeral port is chosen at runtime.
* @param maxPortTrials The maximum number of port-binding attempts, made by sequentially
* incrementing the port number after each failed attempt.
*/
private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
maxPortTrials: Int = 1000)
extends IRabitTracker {
import scala.collection.JavaConverters._
require(numWorkers >= 1, "numWorkers must be greater than or equal to one (1).")
val system = ActorSystem.create("RabitTracker")
val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds)
private[this] val tcpBindingTimeout: Duration = 1 minute
var workerEnvs: Map[String, String] = Map.empty
override def uncaughtException(t: Thread, e: Throwable): Unit = {
handler ? RabitTrackerHandler.InterruptTracker(e)
}
/**
* Start the Rabit tracker.
*
* @param timeout The timeout for awaiting connections from worker nodes.
* Note that when used in Spark applications, because all Spark transformations are
* lazily executed, the I/O time for loading RDDs/DataFrames from external sources
* (local disk, HDFS, S3, etc.) must be taken into account for the timeout value.
* If the timeout value is too small, the Rabit tracker will likely time out before workers
* establish connections to it, due to the overhead of loading data.
* Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
* running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
* @return Boolean flag indicating if the Rabit tracker starts successfully.
*/
private def start(timeout: Duration): Boolean = {
val hostAddress = Option(TrackerProperties.getInstance().getHostIp)
.map(InetAddress.getByName).getOrElse(InetAddress.getLocalHost)
handler ? RabitTrackerHandler.StartTracker(
new InetSocketAddress(hostAddress, port.getOrElse(0)), maxPortTrials, timeout)
// block by waiting for the actor to bind to a port
Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
.asInstanceOf[Future[Map[String, String]]]) match {
case Success(futurePortBound) =>
// The success of the Future is contingent on binding to an InetSocketAddress.
val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess
if (isBound) {
workerEnvs = Await.result(futurePortBound, 0 nano)
}
isBound
case Failure(ex: Throwable) =>
false
}
}
/**
* Start the Rabit tracker.
*
* @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
* connections. If a non-positive value is provided, the tracker
* waits for incoming worker connections indefinitely.
* @return Boolean flag indicating if the Rabit tracker starts successfully.
*/
def start(connectionTimeoutMillis: Long): Boolean = {
if (connectionTimeoutMillis <= 0) {
start(Duration.Inf)
} else {
start(Duration.fromNanos(connectionTimeoutMillis * 1e6))
}
}
def stop(): Unit = {
system.terminate()
}
/**
* Get a Map of necessary environment variables to initiate Rabit workers.
*
* @return HashMap containing tracker information.
*/
def getWorkerEnvs: java.util.Map[String, String] = {
new java.util.HashMap((workerEnvs ++ Map(
"DMLC_NUM_WORKER" -> numWorkers.toString,
"DMLC_NUM_SERVER" -> "0"
)).asJava)
}
/**
* Await workers to complete assigned tasks for at most the given duration.
* This method blocks until timeout or task completion.
*
* @param atMost the maximum execution time for the workers. By default,
* the tracker waits for the workers indefinitely.
* @return 0 if the tasks complete successfully, and non-zero otherwise.
*/
private def waitFor(atMost: Duration): Int = {
// request the completion Future from the tracker actor
Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
.asInstanceOf[Future[Int]]) match {
case Success(futureCompleted) =>
// wait for all workers to complete synchronously.
val statusCode = Try(Await.result(futureCompleted, atMost)) match {
case Success(n) if n == numWorkers =>
IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
case Success(n) if n < numWorkers =>
IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
case Failure(e) =>
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
}
system.terminate()
statusCode
case Failure(ex: Throwable) =>
system.terminate()
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
}
}
/**
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
* This method blocks until timeout or task completion.
*
* @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a
* non-positive number is given, the tracker waits indefinitely.
* @return 0 if the tasks complete successfully, and non-zero otherwise
*/
def waitFor(atMostMillis: Long): Int = {
if (atMostMillis <= 0) {
waitFor(Duration.Inf)
} else {
waitFor(Duration.fromNanos(atMostMillis * 1e6))
}
}
}
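
One usage detail worth recording: IRabitTracker doubles as a Thread.UncaughtExceptionHandler (hence the uncaughtException override above), so the driver-side thread that forces the Spark job can route worker exceptions into the tracker, failing waitFor instead of letting it hang. A sketch mirroring the deleted robustness tests (tracker and dummyTasks assumed in scope):

val sparkThread = new Thread() {
  override def run(): Unit = dummyTasks.foreachPartition(_ => ())  // forces the Spark job
}
sparkThread.setUncaughtExceptionHandler(tracker)  // exceptions arrive via InterruptTracker
sparkThread.start()
assert(tracker.waitFor(0L) == IRabitTracker.TrackerStatus.FAILURE.getStatusCode)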

View File

@@ -1,361 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.handler
import java.net.InetSocketAddress
import java.util.UUID
import scala.concurrent.duration._
import scala.collection.mutable
import scala.concurrent.{Promise, TimeoutException}
import akka.io.{IO, Tcp}
import akka.actor._
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}
import scala.util.{Failure, Random, Success, Try}
/** The Akka actor for handling and coordinating Rabit worker connections.
* This is the main actor for handling socket connections, interacting with the synchronous
* tracker interface, and resolving tree/ring/parent dependencies between workers.
*
* @param numWorkers Number of workers to track.
*/
private[scala] class RabitTrackerHandler(numWorkers: Int)
extends Actor with ActorLogging {
import context.system
import RabitWorkerHandler._
import RabitTrackerHandler._
private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
private[this] val promisedShutdownWorkers = Promise[Int]()
private[this] val tcpManager = IO(Tcp)
// resolves worker connection dependency.
val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")
// workers that have sent "shutdown" signal
private[this] val shutdownWorkers = mutable.Set.empty[Int]
private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
private[this] var maxPortTrials = 0
private[this] var workerConnectionTimeout: Duration = Duration.Inf
private[this] var portTrials = 0
private[this] val startedWorkers = mutable.Set.empty[Int]
val linkMap = new LinkMap(numWorkers)
def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
rank match {
case r if r >= 0 => Some(r)
case _ =>
jobId match {
case "NULL" => None
case jid => jobToRankMap.get(jid)
}
}
}
/**
* Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
* by the RabitWorkerHandler.
*
* @param event Generic Tcp.Event
*/
private def handleTcpEvents(event: Tcp.Event): Unit = event match {
case Tcp.Bound(local) =>
// expect all workers to connect within timeout
log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
log.info(s"Worker connection timeout is $workerConnectionTimeout.")
context.setReceiveTimeout(workerConnectionTimeout)
promisedWorkerEnvs.success(Map(
"DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
"DMLC_TRACKER_PORT" -> local.getPort.toString,
// not required because the world size will be communicated to the
// worker node after the rank is assigned.
"rabit_world_size" -> numWorkers.toString
))
case Tcp.CommandFailed(cmd: Tcp.Bind) =>
if (portTrials < maxPortTrials) {
portTrials += 1
tcpManager ! Tcp.Bind(self,
new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1),
backlog = 256)
}
case Tcp.Connected(remote, local) =>
log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
// hand the connection off to a dedicated per-worker handler actor.
val workerHandler = context.actorOf(RabitWorkerHandler.props(
remote.getAddress.getHostAddress, numWorkers, self, sender()
), s"ConnectionHandler-${UUID.randomUUID().toString}")
val connection = sender()
connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)
actorRefToHost.put(workerHandler, remote.getAddress.getHostName)
}
/**
* Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
* to interact with the tracker interface.
*
* @param trackerMsg control messages sent by RabitTracker class.
*/
private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
trackerMsg match {
case msg: StartTracker =>
maxPortTrials = msg.maxPortTrials
workerConnectionTimeout = msg.connectionTimeout
// if the port number is missing, try binding to a random ephemeral port.
if (msg.addr.getPort == 0) {
tcpManager ! Tcp.Bind(self,
new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
backlog = 256)
} else {
tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
}
sender() ! true
case RequestBoundFuture =>
sender() ! promisedWorkerEnvs.future
case RequestCompletionFuture =>
sender() ! promisedShutdownWorkers.future
case InterruptTracker(e) =>
log.error(e, "Uncaught exception thrown by worker.")
// make sure that waitFor() does not hang indefinitely.
promisedShutdownWorkers.failure(e)
context.stop(self)
}
/**
* Handles messages sent by child actors representing connecting Rabit workers, by brokering
* messages to the dependency resolver, and processing worker commands.
*
* @param workerMsg Message sent by RabitWorkerHandler actors.
*/
private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
case req @ RequestAwaitConnWorkers(_, _) =>
// since the requester may request to connect to other workers
// that have not fully set up, delegate this request to the
// dependency resolver which handles the dependencies properly.
resolver forward req
// ---- Rabit worker commands: start/recover/shutdown/print ----
case WorkerTrackerPrint(_, _, _, msg) =>
log.info(msg.trim)
case WorkerShutdown(rank, _, _) =>
assert(rank >= 0, "Invalid rank.")
assert(!shutdownWorkers.contains(rank))
shutdownWorkers.add(rank)
log.info(s"Received shutdown signal from $rank")
if (shutdownWorkers.size == numWorkers) {
promisedShutdownWorkers.success(shutdownWorkers.size)
}
case WorkerRecover(prevRank, worldSize, jobId) =>
assert(prevRank >= 0)
sender() ! linkMap.assignRank(prevRank)
case WorkerStart(rank, worldSize, jobId) =>
assert(worldSize == numWorkers || worldSize == -1,
s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
)
Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
case Success(wkRank) =>
if (jobId != "NULL") {
jobToRankMap.put(jobId, wkRank)
}
val assignedRank = linkMap.assignRank(wkRank)
sender() ! assignedRank
resolver ! assignedRank
log.info("Received start signal from " +
s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]")
case Failure(ex: IndexOutOfBoundsException) =>
// More than worldSize workers have connected, likely due to executor loss.
// Since Rabit currently does not support crash recovery (because the Allreduce results
// are not cached by the tracker, and because existing workers cannot reestablish
// connections to a newly spawned executor/worker), the most reasonable action here is to
// interrupt the tracker immediately with a failure state.
log.error("Received invalid start signal from " +
s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started."
)
promisedShutdownWorkers.failure(new XGBoostError("Invalid start signal" +
" received from worker, likely due to executor loss."))
case Failure(ex) =>
log.error(ex, "Unexpected error")
promisedShutdownWorkers.failure(ex)
}
// ---- Dependency resolving related messages ----
case msg @ WorkerStarted(host, rank, awaitingAcceptance) =>
log.info(s"Worker $host (rank: $rank) has started.")
resolver forward msg
startedWorkers.add(rank)
if (startedWorkers.size == numWorkers) {
log.info("All workers have started.")
}
case req @ DropFromWaitingList(_) =>
// all peer workers in dependency link map have connected;
// forward message to resolver to update dependencies.
resolver forward req
case _ =>
}
def receive: Actor.Receive = {
case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)
case akka.actor.ReceiveTimeout =>
if (startedWorkers.size < numWorkers) {
promisedShutdownWorkers.failure(
new TimeoutException("Timed out waiting for workers to connect: " +
s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
)
context.stop(self)
}
context.setReceiveTimeout(Duration.Undefined)
}
}
/**
* Resolve the dependency between nodes as they connect to the tracker.
* The enforced dependency is that a worker of rank K depends on its neighbors (from the treeMap
* and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order in which
* workers connect, this constraint assumes that a worker node that connects first is likely
* to finish its setup first.
*/
private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
import RabitWorkerHandler._
context.watch(handler)
case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])
// worker nodes that have connected, but have not sent the WorkerStarted message yet.
private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
private val startedWorkers = mutable.Set.empty[Int]
// worker nodes that have started and are awaiting connections.
private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]
def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
val connSet = awaitConnWorkers.toMap
.filterKeys(k => linkSet.contains(k))
AwaitingConnections(connSet, linkSet.size - connSet.size)
}
def receive: Actor.Receive = {
// a copy of the AssignedRank message that is also sent to the worker
case AssignedRank(rank, tree_neighbors, ring, parent) =>
// the workers that the worker of given `rank` depends on:
// worker of rank K only depends on workers with rank smaller than K.
val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
.filter{ r => r != -1 && r < rank}
log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
dependencyMap.put(rank, dependentWorkers)
case RequestAwaitConnWorkers(rank, toConnectSet) =>
val promise = Promise[AwaitingConnections]()
assert(dependencyMap.contains(rank))
val updatedDependency = dependencyMap(rank) diff startedWorkers
if (updatedDependency.isEmpty) {
// all dependencies are satisfied
log.debug(s"Rank $rank has all dependencies satisfied.")
promise.success(awaitingWorkers(toConnectSet))
} else {
log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
// promise is pending fulfillment due to unresolved dependency
pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
}
sender() ! promise.future
case WorkerStarted(_, started, awaitingAcceptance) =>
startedWorkers.add(started)
if (awaitingAcceptance > 0) {
awaitConnWorkers.put(started, sender())
}
// remove the started rank from all dependencies.
dependencyMap.remove(started)
dependencyMap.foreach { case (r, dset) =>
val updatedDependency = dset diff startedWorkers
// fulfill the future if all dependencies are met (started.)
if (updatedDependency.isEmpty) {
log.debug(s"Rank $r has all dependencies satisfied.")
pendingFulfillment.remove(r).map{
case Fulfillment(toConnectSet, promise) =>
promise.success(awaitingWorkers(toConnectSet))
}
}
dependencyMap.update(r, updatedDependency)
}
case DropFromWaitingList(rank) =>
assert(awaitConnWorkers.remove(rank).isDefined)
case Terminated(ref) =>
if (ref.equals(handler)) {
context.stop(self)
}
}
}
private[scala] object RabitTrackerHandler {
// Messages sent by RabitTracker to this RabitTrackerHandler actor
trait TrackerControlMessage
case object RequestCompletionFuture extends TrackerControlMessage
case object RequestBoundFuture extends TrackerControlMessage
// Start the Rabit tracker at given socket address awaiting worker connections.
// All workers must connect to the tracker before connectionTimeout, otherwise the tracker will
// shut down due to timeout.
case class StartTracker(addr: InetSocketAddress,
maxPortTrials: Int,
connectionTimeout: Duration) extends TrackerControlMessage
// To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
// driver for the distributed training.
case class InterruptTracker(e: Throwable) extends TrackerControlMessage
def props(numWorkers: Int): Props =
Props(new RabitTrackerHandler(numWorkers))
}
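
These control messages are exchanged exclusively through the ask pattern; a condensed sketch of the exchange driven by the RabitTracker front end above (handler, addr, and the timeout values assumed in scope):

import akka.pattern.ask
import scala.concurrent.{Await, Future}

implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30.seconds)
handler ? RabitTrackerHandler.StartTracker(addr, maxPortTrials = 1000, connectionTimeout = 1.hour)
// completes once Tcp.Bound succeeds, yielding DMLC_TRACKER_URI / DMLC_TRACKER_PORT
val boundEnvs = Await.result(handler ? RabitTrackerHandler.RequestBoundFuture,
  askTimeout.duration).asInstanceOf[Future[Map[String, String]]]
// completes when every worker has sent "shutdown", or fails via InterruptTracker
val completion = Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture,
  askTimeout.duration).asInstanceOf[Future[Int]]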

View File

@@ -1,467 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.handler
import java.nio.{ByteBuffer, ByteOrder}
import akka.io.Tcp
import akka.actor._
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Try
/**
* Actor to handle socket communication from worker node.
* To handle fragmentation in received data, this class acts as an FSM
* (finite-state machine), keeping track of its internal state.
*
* @param host IP address of the remote worker
* @param worldSize number of total workers
* @param tracker the RabitTrackerHandler actor reference
*/
private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
connection: ActorRef)
extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
with ActorLogging with Stash {
import RabitWorkerHandler._
import RabitTrackerHelpers._
private[this] var rank: Int = 0
private[this] var port: Int = 0
// indicate if the connection is transient (like "print" or "shutdown")
private[this] var transient: Boolean = false
private[this] var peerClosed: Boolean = false
// number of workers pending acceptance of current worker
private[this] var awaitingAcceptance: Int = 0
private[this] var neighboringWorkers = Set.empty[Int]
// TODO: use a single memory allocation to host all buffers,
// including the transient ones for writing.
private[this] val readBuffer = ByteBuffer.allocate(4096)
.order(ByteOrder.nativeOrder())
// in case the received message is longer than needed,
// stash the spilled over part in this buffer, and send
// to self when transition occurs.
private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
.order(ByteOrder.nativeOrder())
// when setup is complete, need to notify peer handlers
// to reduce the awaiting-connection counter.
private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None
private def resetBuffers(): Unit = {
readBuffer.clear()
if (spillOverBuffer.position() > 0) {
spillOverBuffer.flip()
self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
spillOverBuffer.clear()
}
}
private def stashSpillOver(buf: ByteBuffer): Unit = {
if (buf.remaining() > 0) spillOverBuffer.put(buf)
}
def getNeighboringWorkers: Set[Int] = neighboringWorkers
def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
val readBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
readBuffer.flip()
val rank = readBuffer.getInt()
val worldSize = readBuffer.getInt()
val jobId = readBuffer.getString
val command = readBuffer.getString
val trackerCommand = command match {
case "start" => WorkerStart(rank, worldSize, jobId)
case "shutdown" =>
transient = true
WorkerShutdown(rank, worldSize, jobId)
case "recover" =>
require(rank >= 0, "Invalid rank for recovering worker.")
WorkerRecover(rank, worldSize, jobId)
case "print" =>
transient = true
WorkerTrackerPrint(rank, worldSize, jobId, readBuffer.getString)
}
stashSpillOver(readBuffer)
trackerCommand
}
startWith(AwaitingHandshake, DataStruct())
when(AwaitingHandshake) {
case Event(Tcp.Received(magic), _) =>
assert(magic.length == 4)
val purportedMagic = magic.asNativeOrderByteBuffer.getInt
assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")
// echo back the magic number
connection ! Tcp.Write(magic)
goto(AwaitingCommand) using StructTrackerCommand
}
when(AwaitingCommand) {
case Event(Tcp.Received(bytes), validator) =>
bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
if (validator.verify(readBuffer)) {
Try(decodeCommand(readBuffer)) match {
case scala.util.Success(decodedCommand) =>
tracker ! decodedCommand
case scala.util.Failure(th: java.nio.BufferUnderflowException) =>
// BufferUnderflowException would occur if the message to print has not arrived yet.
// Do nothing, wait for next Tcp.Received event
case scala.util.Failure(th: Throwable) => throw th
}
}
stay
// when rank for a worker is assigned, send encoded rank information
// back to worker over Tcp socket.
case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")
rank = assignedRank
// ranks from the ring
val ringRanks = List(
// ringPrev
if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
// ringNext
if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
)
// update the set of all linked workers to current worker.
neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet
connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
// to prevent reading before state transition
connection ! Tcp.SuspendReading
goto(BuildingLinkMap) using StructNodes
}
when(BuildingLinkMap) {
case Event(Tcp.Received(bytes), validator) =>
bytes.asByteBuffers.foreach { buf =>
readBuffer.put(buf)
}
if (validator.verify(readBuffer)) {
readBuffer.flip()
// for a freshly started worker, numConnected should be 0.
val numConnected = readBuffer.getInt()
val toConnectSet = neighboringWorkers.diff(
(0 until numConnected).map { index => readBuffer.getInt() }.toSet)
// check which workers are currently awaiting connections
tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
}
stay
// got a Future from the tracker (resolver) about workers that are
// currently awaiting connections (particularly from this node.)
case Event(future: Future[_], _) =>
// blocks execution until all dependencies for the current worker are resolved.
Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
// numNotReachable is the number of workers that currently
// cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
// connections from those currently non-reachable nodes in the future.
case AwaitingConnections(waitConnNodes, numNotReachable) =>
log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")
val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
buf.putInt(waitConnNodes.size).putInt(numNotReachable)
buf.flip()
// cache this message until the final state (SetupComplete)
pendingAcknowledgement = Some(AcknowledgeAcceptance(
waitConnNodes, numNotReachable))
connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
if (waitConnNodes.isEmpty) {
connection ! Tcp.SuspendReading
goto(AwaitingErrorCount)
}
else {
waitConnNodes.foreach { case (peerRank, peerRef) =>
peerRef ! RequestWorkerHostPort
}
// a countdown for DivulgedWorkerHostPort messages.
stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
}
}
case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
val hostBytes = peerHost.getBytes()
val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
.order(ByteOrder.nativeOrder())
buffer.putInt(peerHost.length).put(hostBytes)
.putInt(peerPort).putInt(peerRank)
buffer.flip()
connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))
if (data.counter == 0) {
// to prevent reading before state transition
connection ! Tcp.SuspendReading
goto(AwaitingErrorCount)
}
else {
stay using data.decrement()
}
}
when(AwaitingErrorCount) {
case Event(Tcp.Received(numErrors), _) =>
val buf = numErrors.asNativeOrderByteBuffer
buf.getInt match {
case 0 =>
stashSpillOver(buf)
goto(AwaitingPortNumber)
case _ =>
stashSpillOver(buf)
goto(BuildingLinkMap) using StructNodes
}
}
when(AwaitingPortNumber) {
case Event(Tcp.Received(assignedPort), _) =>
assert(assignedPort.length == 4)
port = assignedPort.asNativeOrderByteBuffer.getInt
log.debug(s"Rank $rank listening @ $host:$port")
// wait until the worker closes connection.
if (peerClosed) goto(SetupComplete) else stay
case Event(Tcp.PeerClosed, _) =>
peerClosed = true
if (port == 0) stay else goto(SetupComplete)
}
when(SetupComplete) {
case Event(ReduceWaitCount(count: Int), _) =>
awaitingAcceptance -= count
// check peerClosed to avoid prematurely stopping this actor (which sends RST to worker)
if (awaitingAcceptance == 0 && peerClosed) {
tracker ! DropFromWaitingList(rank)
// no longer needed.
context.stop(self)
}
stay
case Event(AcknowledgeAcceptance(peers, numBad), _) =>
awaitingAcceptance = numBad
tracker ! WorkerStarted(host, rank, awaitingAcceptance)
peers.values.foreach { peer =>
peer ! ReduceWaitCount(1)
}
if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill
stay
// can only divulge the complete host and port information
// when this worker is declared fully connected (otherwise
// port information is still missing.)
case Event(RequestWorkerHostPort, _) =>
sender() ! DivulgedWorkerHostPort(rank, host, port)
stay
}
onTransition {
// reset buffer when state transitions as data becomes stale
case _ -> SetupComplete =>
connection ! Tcp.ResumeReading
resetBuffers()
if (pendingAcknowledgement.isDefined) {
self ! pendingAcknowledgement.get
}
case _ =>
connection ! Tcp.ResumeReading
resetBuffers()
}
// default message handler
whenUnhandled {
case Event(Tcp.PeerClosed, _) =>
peerClosed = true
if (transient) context.stop(self)
stay
}
}
private[scala] object RabitWorkerHandler {
val MAGIC_NUMBER = 0xff99
// Finite states of this actor, which acts like a FSM.
// The following states are defined in order as the FSM progresses.
sealed trait State
// [1] Initial state, awaiting worker to send magic number per protocol.
case object AwaitingHandshake extends State
// [2] Awaiting worker to send command (start/print/recover/shutdown etc.)
case object AwaitingCommand extends State
// [3] Brokers connections between workers per ring/tree/parent link map.
case object BuildingLinkMap extends State
// [4] A transient state in which the worker reports the number of errors in establishing
// connections to other peer workers. If no errors, transition to next state.
case object AwaitingErrorCount extends State
// [5] Awaiting the worker to report its port number for accepting connections from peer workers.
// This port number information is later forwarded to linked workers.
case object AwaitingPortNumber extends State
// [6] Final state after completing the setup with the connecting worker. At this stage, the
// worker will have closed the Tcp connection. The actor remains alive to handle messages from
// peer actors representing workers with pending setups.
case object SetupComplete extends State
sealed trait DataField
case object IntField extends DataField
// an integer preceding the actual string
case object StringField extends DataField
case object IntSeqField extends DataField
object DataStruct {
def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
}
// Internal data pertaining to individual state, used to verify the validity of packets sent by
// workers.
case class DataStruct(fields: Seq[DataField], counter: Int) {
/**
* Validate whether the provided buffer is complete (i.e., contains
* all data fields specified for this DataStruct.)
*
* @param buf a byte buffer containing received data.
*/
def verify(buf: ByteBuffer): Boolean = {
if (fields.isEmpty) return true
val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
dupBuf.flip()
Try(fields.foldLeft(true) {
case (complete, field) =>
val remBytes = dupBuf.remaining()
complete && (remBytes > 0) && (remBytes >= (field match {
case IntField =>
dupBuf.position(dupBuf.position() + 4)
4
case StringField =>
val strLen = dupBuf.getInt
dupBuf.position(dupBuf.position() + strLen)
4 + strLen
case IntSeqField =>
val seqLen = dupBuf.getInt
dupBuf.position(dupBuf.position() + seqLen * 4)
4 + seqLen * 4
}))
}).getOrElse(false)
}
def increment(): DataStruct = DataStruct(fields, counter + 1)
def decrement(): DataStruct = DataStruct(fields, counter - 1)
}
val StructNodes = DataStruct(List(IntSeqField), 0)
val StructTrackerCommand = DataStruct(List(
IntField, IntField, StringField, StringField
), 0)
// ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ----
// RabitWorkerHandler --> RabitTrackerHandler
sealed trait RabitWorkerRequest
// RabitWorkerHandler <-- RabitTrackerHandler
sealed trait RabitWorkerResponse
// Representations of decoded worker commands.
abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
def rank: Int
def worldSize: Int
def jobId: String
def encode: ByteString = {
val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length)
.order(ByteOrder.nativeOrder())
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
.putInt(command.length).put(command.getBytes()).flip()
ByteString.fromByteBuffer(buf)
}
}
case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("start")
case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("shutdown")
case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("recover")
case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
extends TrackerCommand("print") {
override def encode: ByteString = {
val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length)
.order(ByteOrder.nativeOrder())
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
.putInt(command.length).put(command.getBytes())
.putInt(msg.length).put(msg.getBytes()).flip()
ByteString.fromByteBuffer(buf)
}
}
// Request to remove the worker of given rank from the list of workers awaiting peer connections.
case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
// Notify the tracker that the worker of given rank has finished setup and started.
case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
extends RabitWorkerRequest
// Request the set of workers to connect to, according to the LinkMap structure.
case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
extends RabitWorkerRequest
// Response from the tracker: the workers currently awaiting connections, plus the number not yet reachable.
case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
extends RabitWorkerResponse
// ---- Messages between ConnectionHandler actors ----
sealed trait IntraWorkerMessage
// Notify neighboring workers to decrease the counter of awaiting workers by `count`.
case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
// Request host and port information from peer ConnectionHandler actors (acting on behalf of
// connecting workers.) This message will be brokered by RabitTrackerHandler.
case object RequestWorkerHostPort extends IntraWorkerMessage
// Response to the above request
case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
// A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
extends IntraWorkerMessage
// ---- End of message definitions ----
def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
}
}
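
The DataStruct.verify contract can be exercised standalone; a small sketch of the command layout decoded above (native byte order, length-prefixed strings, with the same 22-byte "start" command used in the test suite below):

val buf = ByteBuffer.allocate(22).order(ByteOrder.nativeOrder())
buf.putInt(0).putInt(4)                // rank, worldSize
  .putInt(1).put("0".getBytes())       // jobId "0", length-prefixed
  .putInt(5).put("start".getBytes())   // command "start", length-prefixed
RabitWorkerHandler.StructTrackerCommand.verify(buf)  // true: all four fields are present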

View File

@@ -1,136 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.util
import java.nio.{ByteBuffer, ByteOrder}
/**
* The assigned rank to a connecting Rabit worker, along with the information of the ranks of
* its linked peer workers, which are critical to perform Allreduce.
* When RabitWorkerHandler relays "start" or "recover" commands from the connecting worker
* client, RabitTrackerHandler uses LinkMap to figure out the linkage relationships, and responds
* with this class as a message, which is later encoded as a byte string and sent over the
* socket connection to the worker client.
*
* @param rank assigned rank (ranked by worker connection order: first worker connecting to the
* tracker is assigned rank 0, second with rank 1, etc.)
* @param neighbors ranks of neighboring workers in a tree map.
* @param ring ranks of neighboring workers in a ring map.
* @param parent rank of the parent worker.
*/
private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
ring: (Int, Int), parent: Int) {
/**
* Encode the AssignedRank message into byte sequence for socket communication with Rabit worker
* client.
* @param worldSize the number of total distributed workers. Must match `numWorkers` used in
* LinkMap.
* @return a ByteBuffer containing encoded data.
*/
def toByteBuffer(worldSize: Int): ByteBuffer = {
val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder())
buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length)
// neighbors in tree structure
neighbors.foreach { n => buffer.putInt(n) }
buffer.putInt(if (ring._1 != -1 && ring._1 != rank) ring._1 else -1)
buffer.putInt(if (ring._2 != -1 && ring._2 != rank) ring._2 else -1)
buffer.flip()
buffer
}
}
private[rabit] class LinkMap(numWorkers: Int) {
private def getNeighbors(rank: Int): Seq[Int] = {
val rank1 = rank + 1
Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
r >= 0 && r < numWorkers
}
}
/**
* Construct a ring structure that tends to share nodes with the tree.
*
* @param treeMap
* @param parentMap
* @param rank
* @return Seq[Int] instance starting from rank.
*/
private def constructShareRing(treeMap: Map[Int, Seq[Int]],
parentMap: Map[Int, Int],
rank: Int = 0): Seq[Int] = {
treeMap(rank).toSet - parentMap(rank) match {
case emptySet if emptySet.isEmpty =>
List(rank)
case connectionSet =>
connectionSet.zipWithIndex.foldLeft(List(rank)) {
case (ringSeq, (v, cnt)) =>
val vConnSeq = constructShareRing(treeMap, parentMap, v)
vConnSeq match {
case vconn if vconn.size == cnt + 1 =>
ringSeq ++ vconn.reverse
case vconn =>
ringSeq ++ vconn
}
}
}
}
/**
* Construct a ring connection used to recover local data.
*
* @param treeMap
* @param parentMap
*/
private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = {
assert(parentMap(0) == -1)
val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
assert(sharedRing.length == treeMap.size)
(0 until numWorkers).map { r =>
val rPrev = (r + numWorkers - 1) % numWorkers
val rNext = (r + 1) % numWorkers
sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
}.toMap
}
private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
private[this] val parentMap_ = (0 until numWorkers).map{ r => r -> ((r + 1) / 2 - 1) }.toMap
private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)
val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
case ((rmap, k), i) =>
val kNext = ringMap_(k)._2
(rmap ++ Map(kNext -> (i + 1)), kNext)
}._1
val ringMap = ringMap_.map {
case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
}
val treeMap = treeMap_.map {
case (k, vSeq) => rMap_(k) -> vSeq.map{ v => rMap_(v) }
}
val parentMap = parentMap_.map {
case (k, v) if k == 0 =>
rMap_(k) -> -1
case (k, v) =>
rMap_(k) -> rMap_(v)
}
def assignRank(rank: Int): AssignedRank = {
AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
}
}
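
A worked example of the heap arithmetic may help: getNeighbors uses a 1-based binary-heap layout, with candidates (r + 1) / 2 - 1 (parent), (r + 1) * 2 - 1 (left child) and (r + 1) * 2 (right child), filtered to [0, numWorkers). The ranks below are the raw ones, before the ring renumbering applied by rMap_:

val lm = new LinkMap(4)
// raw tree for numWorkers = 4:
//   rank 0 -> Seq(1, 2)   (candidates -1, 1, 2)
//   rank 1 -> Seq(0, 3)   (candidates 0, 3, 4)
//   rank 2 -> Seq(0)      (candidates 0, 5, 6)
//   rank 3 -> Seq(1)      (candidates 1, 7, 8)
// rank 0 is the root: parentMap(0) == -1
val AssignedRank(rank, neighbors, ring, parent) = lm.assignRank(0)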

View File

@@ -1,39 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.util
import java.nio.{ByteOrder, ByteBuffer}
import akka.util.ByteString
private[rabit] object RabitTrackerHelpers {
implicit class ByteStringHelpers(bs: ByteString) {
// Java by default uses big endian. Enforce native endian so that
// the byte order is consistent with the workers.
def asNativeOrderByteBuffer: ByteBuffer = {
bs.asByteBuffer.order(ByteOrder.nativeOrder())
}
}
implicit class ByteBufferHelpers(buf: ByteBuffer) {
def getString: String = {
val len = buf.getInt()
val stringBuffer = ByteBuffer.allocate(len).order(ByteOrder.nativeOrder())
buf.get(stringBuffer.array(), 0, len)
new String(stringBuffer.array(), "utf-8")
}
}
}
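
A small round-trip sketch of the length-prefixed string convention these helpers assume (importing RabitTrackerHelpers._ brings the implicits into scope):

import RabitTrackerHelpers._

val buf = ByteBuffer.allocate(4 + 5).order(ByteOrder.nativeOrder())
buf.putInt(5).put("start".getBytes("utf-8"))
buf.flip()
buf.getString   // "start": reads the 4-byte length, then that many bytes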

View File

@@ -1,255 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit
import java.nio.{ByteBuffer, ByteOrder}
import akka.actor.{ActorRef, ActorSystem}
import akka.io.Tcp
import akka.testkit.{ImplicitSender, TestFSMRef, TestKit, TestProbe}
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler._
import ml.dmlc.xgboost4j.scala.rabit.util.LinkMap
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpecLike, Matchers}
import scala.concurrent.Promise
object RabitTrackerConnectionHandlerTest {
def intSeqToByteString(seq: Seq[Int]): ByteString = {
val buf = ByteBuffer.allocate(seq.length * 4).order(ByteOrder.nativeOrder())
seq.foreach { i => buf.putInt(i) }
buf.flip()
ByteString.fromByteBuffer(buf)
}
}
@RunWith(classOf[JUnitRunner])
class RabitTrackerConnectionHandlerTest
extends TestKit(ActorSystem("RabitTrackerConnectionHandlerTest"))
with FlatSpecLike with Matchers with ImplicitSender {
import RabitTrackerConnectionHandlerTest._
val magic = intSeqToByteString(List(0xff99))
"RabitTrackerConnectionHandler" should "handle Rabit client 'start' command properly" in {
val trackerProbe = TestProbe()
val connProbe = TestProbe()
val worldSize = 4
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
trackerProbe.ref, connProbe.ref))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
// send mock magic number
fsm ! Tcp.Received(magic)
connProbe.expectMsg(Tcp.Write(magic))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
// ResumeReading should be seen once state transitions
connProbe.expectMsg(Tcp.ResumeReading)
// send mock tracker command in fragments: the handler should be able to handle it.
val bufRank = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
bufRank.putInt(0).putInt(worldSize).flip()
val bufJobId = ByteBuffer.allocate(5).order(ByteOrder.nativeOrder())
bufJobId.putInt(1).put(Array[Byte]('0')).flip()
val bufCmd = ByteBuffer.allocate(9).order(ByteOrder.nativeOrder())
bufCmd.putInt(5).put("start".getBytes()).flip()
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufRank))
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufJobId))
// the state should not change for incomplete command data.
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
// send the last fragment, and expect message at tracker actor.
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
val linkMap = new LinkMap(worldSize)
val assignedRank = linkMap.assignRank(0)
trackerProbe.reply(assignedRank)
connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
assignedRank.toByteBuffer(worldSize)
)))
// reading should be suspended upon transitioning to BuildingLinkMap
connProbe.expectMsg(Tcp.SuspendReading)
// state should transition with corresponding state-data changes.
fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
connProbe.expectMsg(Tcp.ResumeReading)
// since the connection handler in test has rank 0, it will not have any nodes to connect to.
fsm ! Tcp.Received(intSeqToByteString(List(0)))
trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
// return mock response to the connection handler
val awaitConnPromise = Promise[AwaitingConnections]()
awaitConnPromise.success(AwaitingConnections(Map.empty[Int, ActorRef],
fsm.underlyingActor.getNeighboringWorkers.size
))
fsm ! awaitConnPromise.future
connProbe.expectMsg(Tcp.Write(
intSeqToByteString(List(0, fsm.underlyingActor.getNeighboringWorkers.size))
))
connProbe.expectMsg(Tcp.SuspendReading)
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingErrorCount
connProbe.expectMsg(Tcp.ResumeReading)
// send mock error count (0)
fsm ! Tcp.Received(intSeqToByteString(List(0)))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
connProbe.expectMsg(Tcp.ResumeReading)
// simulate Tcp.PeerClosed event first, then Tcp.Received to test handling of async events.
fsm ! Tcp.PeerClosed
// state should not transition
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
fsm ! Tcp.Received(intSeqToByteString(List(32768)))
fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
connProbe.expectMsg(Tcp.ResumeReading)
trackerProbe.expectMsg(RabitWorkerHandler.WorkerStarted("localhost", 0, 2))
val handlerStopProbe = TestProbe()
handlerStopProbe watch fsm
// simulate connections from other workers by mocking ReduceWaitCount commands
fsm ! RabitWorkerHandler.ReduceWaitCount(1)
fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
fsm ! RabitWorkerHandler.ReduceWaitCount(1)
trackerProbe.expectMsg(RabitWorkerHandler.DropFromWaitingList(0))
handlerStopProbe.expectTerminated(fsm)
// all done.
}
it should "forward print command to tracker" in {
val trackerProbe = TestProbe()
val connProbe = TestProbe()
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
trackerProbe.ref, connProbe.ref))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
fsm ! Tcp.Received(magic)
connProbe.expectMsg(Tcp.Write(magic))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
// ResumeReading should be seen once state transitions
connProbe.expectMsg(Tcp.ResumeReading)
val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!")
fsm ! Tcp.Received(printCmd.encode)
trackerProbe.expectMsg(printCmd)
}
it should "handle fragmented print command without throwing exception" in {
val trackerProbe = TestProbe()
val connProbe = TestProbe()
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
trackerProbe.ref, connProbe.ref))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
fsm ! Tcp.Received(magic)
connProbe.expectMsg(Tcp.Write(magic))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
// ResumeReading should be seen once state transitions
connProbe.expectMsg(Tcp.ResumeReading)
val printCmd = WorkerTrackerPrint(0, 4, "0", "fragmented!")
// 4 (rank: Int) + 4 (worldSize: Int) + (4+1) (jobId: String) + (4+5) (command: String) = 22
val (partialMessage, remainder) = printCmd.encode.splitAt(22)
// make sure that the partialMessage in itself is a valid command
val partialMsgBuf = ByteBuffer.allocate(22).order(ByteOrder.nativeOrder())
partialMsgBuf.put(partialMessage.asByteBuffer)
RabitWorkerHandler.StructTrackerCommand.verify(partialMsgBuf) shouldBe true
fsm ! Tcp.Received(partialMessage)
fsm ! Tcp.Received(remainder)
trackerProbe.expectMsg(printCmd)
}
it should "handle spill-over Tcp data correctly between state transition" in {
val trackerProbe = TestProbe()
val connProbe = TestProbe()
val worldSize = 4
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
trackerProbe.ref, connProbe.ref))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
// send mock magic number
fsm ! Tcp.Received(magic)
connProbe.expectMsg(Tcp.Write(magic))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
// ResumeReading should be seen once state transitions
connProbe.expectMsg(Tcp.ResumeReading)
// send mock tracker command in fragments: the handler should be able to handle it.
val bufCmd = ByteBuffer.allocate(26).order(ByteOrder.nativeOrder())
bufCmd.putInt(0).putInt(worldSize).putInt(1).put(Array[Byte]('0'))
.putInt(5).put("start".getBytes())
// spilled-over data
.putInt(0).flip()
// send data with 4 extra bytes corresponding to the next state.
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
val linkMap = new LinkMap(worldSize)
val assignedRank = linkMap.assignRank(0)
trackerProbe.reply(assignedRank)
connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
assignedRank.toByteBuffer(worldSize)
)))
// reading should be suspended upon transitioning to BuildingLinkMap
connProbe.expectMsg(Tcp.SuspendReading)
// state should transition with corresponding state-data changes.
fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
connProbe.expectMsg(Tcp.ResumeReading)
// the handler should be able to handle spill-over data, and stash it until state transition.
trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
}
}