[jvm-packages] Scala implementation of the Rabit tracker. (#1612)

* [jvm-packages] Scala implementation of the Rabit tracker.

A Scala implementation of RabitTracker that is interface-interchangable with the
Java implementation, ported from `tracker.py` in the
[dmlc-core project](https://github.com/dmlc/dmlc-core).

* [jvm-packages] Updated Akka dependency in pom.xml.

* Refactored the RabitTracker directory structure.

* Fixed premature stopping of connection handler.

Added a new finite state "AwaitingPortNumber" to explicitly wait for the
worker to send the port, and close the connection. Stopping the actor
prematurely sends a TCP RST to the worker, causing the worker to crash
on AssertionError.

* Added interface IRabitTracker so that user can switch implementations.

* Default timeout duration changes.

* Dependency for Akka tests.

* Removed the main function of RabitTracker.

* A skeleton for testing Akka-based Rabit tracker.

* waitFor() in RabitTracker no longer throws exceptions.

* Completed unit test for the 'start' command of Rabit tracker.

* Preliminary support for Rabit Allreduce via JNI (no prepare function support yet.)

* Fixed the default timeout duration.

* Use Java container to avoid serialization issues due to intermediate wrappers.

* Added tests for Allreduce/model training using Scala Rabit tracker.

* Added spill-over unit test for the Scala Rabit tracker.

* Fixed a typo.

* Overhaul of RabitTracker interface per code review.

  - Removed methods start() waitFor() (no arguments) from IRabitTracker.
  - The timeout in start(timeout) is now worker connection timeout, as tcp
    socket binding timeout is less intuitive.
  - Dropped time unit from start(...) and waitFor(...) methods; the default
    time unit is millisecond.
  - Moved random port number generation into the RabitTrackerHandler.
  - Moved all Rabit-related classes to package ml.dmlc.xgboost4j.scala.rabit.

* More code refactoring and comments.

* Unified timeout constants. Readable tracker status code.

* Add comments to indicate that allReduce is for tests only. Removed all other variants.

* Removed unused imports.

* Simplified signatures of training methods.

 - Moved TrackerConf into parameter map.
 - Changed GeneralParams so that TrackerConf becomes a standalone parameter.
 - Updated test cases accordingly.

* Changed monitoring strategies.

* Reverted monitoring changes.

* Update test case for Rabit AllReduce.

* Mix in UncaughtExceptionHandler into IRabitTracker to prevent tracker from hanging due to exceptions thrown by workers.

* More comprehensive test cases for exception handling and worker connection timeout.

* Handle executor loss due to unknown cause: the newly spawned executor will attempt to connect to the tracker. Interrupt tracker in such case.

* Per code-review, removed training timeout from TrackerConf. Timeout logic must be implemented explicitly and externally in the driver code.

* Reverted scalastyle-config changes.

* Visibility scope change. Interface tweaks.

* Use match pattern to handle tracker_conf parameter.

* Minor clarification in JNI code.

* Clearer intent in match pattern to suppress warnings.

* Removed Future from constructor. Block in start() and waitFor() instead.

* Revert inadvertent comment changes.

* Removed debugging information.

* Updated test cases that are a bit finicky.

* Added comments on the reasoning behind the unit tests for testing Rabit tracker robustness.
This commit is contained in:
Xin Yin
2016-12-07 09:35:42 -05:00
committed by Nan Zhu
parent 7078c41dad
commit e7fbc8591f
19 changed files with 1910 additions and 25 deletions

View File

@@ -0,0 +1,224 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit
import java.nio.{ByteBuffer, ByteOrder}
import akka.actor.{ActorRef, ActorSystem}
import akka.io.Tcp
import akka.testkit.{ImplicitSender, TestFSMRef, TestKit, TestProbe}
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler._
import ml.dmlc.xgboost4j.scala.rabit.util.LinkMap
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpecLike, Matchers}
import scala.concurrent.Promise
object RabitTrackerConnectionHandlerTest {
def intSeqToByteString(seq: Seq[Int]): ByteString = {
val buf = ByteBuffer.allocate(seq.length * 4).order(ByteOrder.nativeOrder())
seq.foreach { i => buf.putInt(i) }
buf.flip()
ByteString.fromByteBuffer(buf)
}
}
@RunWith(classOf[JUnitRunner])
class RabitTrackerConnectionHandlerTest
extends TestKit(ActorSystem("RabitTrackerConnectionHandlerTest"))
with FlatSpecLike with Matchers with ImplicitSender {
import RabitTrackerConnectionHandlerTest._
val magic = intSeqToByteString(List(0xff99))
"RabitTrackerConnectionHandler" should "handle Rabit client 'start' command properly" in {
val trackerProbe = TestProbe()
val connProbe = TestProbe()
val worldSize = 4
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
trackerProbe.ref, connProbe.ref))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
// send mock magic number
fsm ! Tcp.Received(magic)
connProbe.expectMsg(Tcp.Write(magic))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
// ResumeReading should be seen once state transitions
connProbe.expectMsg(Tcp.ResumeReading)
// send mock tracker command in fragments: the handler should be able to handle it.
val bufRank = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
bufRank.putInt(0).putInt(worldSize).flip()
val bufJobId = ByteBuffer.allocate(5).order(ByteOrder.nativeOrder())
bufJobId.putInt(1).put(Array[Byte]('0')).flip()
val bufCmd = ByteBuffer.allocate(9).order(ByteOrder.nativeOrder())
bufCmd.putInt(5).put("start".getBytes()).flip()
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufRank))
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufJobId))
// the state should not change for incomplete command data.
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
// send the last fragment, and expect message at tracker actor.
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
val linkMap = new LinkMap(worldSize)
val assignedRank = linkMap.assignRank(0)
trackerProbe.reply(assignedRank)
connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
assignedRank.toByteBuffer(worldSize)
)))
// reading should be suspended upon transitioning to BuildingLinkMap
connProbe.expectMsg(Tcp.SuspendReading)
// state should transition with according state data changes.
fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
connProbe.expectMsg(Tcp.ResumeReading)
// since the connection handler in test has rank 0, it will not have any nodes to connect to.
fsm ! Tcp.Received(intSeqToByteString(List(0)))
trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
// return mock response to the connection handler
val awaitConnPromise = Promise[AwaitingConnections]()
awaitConnPromise.success(AwaitingConnections(Map.empty[Int, ActorRef],
fsm.underlyingActor.getNeighboringWorkers.size
))
fsm ! awaitConnPromise.future
connProbe.expectMsg(Tcp.Write(
intSeqToByteString(List(0, fsm.underlyingActor.getNeighboringWorkers.size))
))
connProbe.expectMsg(Tcp.SuspendReading)
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingErrorCount
connProbe.expectMsg(Tcp.ResumeReading)
// send mock error count (0)
fsm ! Tcp.Received(intSeqToByteString(List(0)))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
connProbe.expectMsg(Tcp.ResumeReading)
// simulate Tcp.PeerClosed event first, then Tcp.Received to test handling of async events.
fsm ! Tcp.PeerClosed
// state should not transition
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
fsm ! Tcp.Received(intSeqToByteString(List(32768)))
fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
connProbe.expectMsg(Tcp.ResumeReading)
trackerProbe.expectMsg(RabitWorkerHandler.WorkerStarted("localhost", 0, 2))
val handlerStopProbe = TestProbe()
handlerStopProbe watch fsm
// simulate connections from other workers by mocking ReduceWaitCount commands
fsm ! RabitWorkerHandler.ReduceWaitCount(1)
fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
fsm ! RabitWorkerHandler.ReduceWaitCount(1)
trackerProbe.expectMsg(RabitWorkerHandler.DropFromWaitingList(0))
handlerStopProbe.expectTerminated(fsm)
// all done.
}
it should "forward print command to tracker" in {
val trackerProbe = TestProbe()
val connProbe = TestProbe()
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
trackerProbe.ref, connProbe.ref))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
fsm ! Tcp.Received(magic)
connProbe.expectMsg(Tcp.Write(magic))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
// ResumeReading should be seen once state transitions
connProbe.expectMsg(Tcp.ResumeReading)
val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!")
fsm ! Tcp.Received(printCmd.encode)
trackerProbe.expectMsg(printCmd)
}
it should "handle spill-over Tcp data correctly between state transition" in {
val trackerProbe = TestProbe()
val connProbe = TestProbe()
val worldSize = 4
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
trackerProbe.ref, connProbe.ref))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
// send mock magic number
fsm ! Tcp.Received(magic)
connProbe.expectMsg(Tcp.Write(magic))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
// ResumeReading should be seen once state transitions
connProbe.expectMsg(Tcp.ResumeReading)
// send mock tracker command in fragments: the handler should be able to handle it.
val bufCmd = ByteBuffer.allocate(26).order(ByteOrder.nativeOrder())
bufCmd.putInt(0).putInt(worldSize).putInt(1).put(Array[Byte]('0'))
.putInt(5).put("start".getBytes())
// spilled-over data
.putInt(0).flip()
// send data with 4 extra bytes corresponding to the next state.
fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
val linkMap = new LinkMap(worldSize)
val assignedRank = linkMap.assignRank(0)
trackerProbe.reply(assignedRank)
connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
assignedRank.toByteBuffer(worldSize)
)))
// reading should be suspended upon transitioning to BuildingLinkMap
connProbe.expectMsg(Tcp.SuspendReading)
// state should transition with according state data changes.
fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
connProbe.expectMsg(Tcp.ResumeReading)
// the handler should be able to handle spill-over data, and stash it until state transition.
trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
}
}