[jvm-packages] Fixed java.nio.BufferUnderFlow issue in Scala Rabit tracker. (#1993)

* [jvm-packages] Scala implementation of the Rabit tracker.

A Scala implementation of RabitTracker that is interface-interchangable with the
Java implementation, ported from `tracker.py` in the
[dmlc-core project](https://github.com/dmlc/dmlc-core).

* [jvm-packages] Updated Akka dependency in pom.xml.

* Refactored the RabitTracker directory structure.

* Fixed premature stopping of connection handler.

Added a new finite state "AwaitingPortNumber" to explicitly wait for the
worker to send the port, and close the connection. Stopping the actor
prematurely sends a TCP RST to the worker, causing the worker to crash
on AssertionError.

* Added interface IRabitTracker so that user can switch implementations.

* Default timeout duration changes.

* Dependency for Akka tests.

* Removed the main function of RabitTracker.

* A skeleton for testing Akka-based Rabit tracker.

* waitFor() in RabitTracker no longer throws exceptions.

* Completed unit test for the 'start' command of Rabit tracker.

* Preliminary support for Rabit Allreduce via JNI (no prepare function support yet.)

* Fixed the default timeout duration.

* Use Java container to avoid serialization issues due to intermediate wrappers.

* Added tests for Allreduce/model training using Scala Rabit tracker.

* Added spill-over unit test for the Scala Rabit tracker.

* Fixed a typo.

* Overhaul of RabitTracker interface per code review.

  - Removed methods start() waitFor() (no arguments) from IRabitTracker.
  - The timeout in start(timeout) is now worker connection timeout, as tcp
    socket binding timeout is less intuitive.
  - Dropped time unit from start(...) and waitFor(...) methods; the default
    time unit is millisecond.
  - Moved random port number generation into the RabitTrackerHandler.
  - Moved all Rabit-related classes to package ml.dmlc.xgboost4j.scala.rabit.

* More code refactoring and comments.

* Unified timeout constants. Readable tracker status code.

* Add comments to indicate that allReduce is for tests only. Removed all other variants.

* Removed unused imports.

* Simplified signatures of training methods.

 - Moved TrackerConf into parameter map.
 - Changed GeneralParams so that TrackerConf becomes a standalone parameter.
 - Updated test cases accordingly.

* Changed monitoring strategies.

* Reverted monitoring changes.

* Update test case for Rabit AllReduce.

* Mix in UncaughtExceptionHandler into IRabitTracker to prevent tracker from hanging due to exceptions thrown by workers.

* More comprehensive test cases for exception handling and worker connection timeout.

* Handle executor loss due to unknown cause: the newly spawned executor will attempt to connect to the tracker. Interrupt tracker in such case.

* Per code-review, removed training timeout from TrackerConf. Timeout logic must be implemented explicitly and externally in the driver code.

* Reverted scalastyle-config changes.

* Visibility scope change. Interface tweaks.

* Use match pattern to handle tracker_conf parameter.

* Minor clarification in JNI code.

* Clearer intent in match pattern to suppress warnings.

* Removed Future from constructor. Block in start() and waitFor() instead.

* Revert inadvertent comment changes.

* Removed debugging information.

* Updated test cases that are a bit finicky.

* Added comments on the reasoning behind the unit tests for testing Rabit tracker robustness.

* Fixed BufferUnderFlow bug in decoding tracker 'print' command.

* Merge conflicts resolution.
This commit is contained in:
Xin Yin 2017-02-04 13:20:39 -05:00 committed by Nan Zhu
parent 2250b9b6d2
commit 4fb7fdb240
2 changed files with 46 additions and 9 deletions

View File

@ -84,12 +84,15 @@ private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: A
def getNeighboringWorkers: Set[Int] = neighboringWorkers def getNeighboringWorkers: Set[Int] = neighboringWorkers
def decodeCommand(buffer: ByteBuffer): TrackerCommand = { def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
val rank = buffer.getInt() val readBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
val worldSize = buffer.getInt() readBuffer.flip()
val jobId = buffer.getString
val command = buffer.getString val rank = readBuffer.getInt()
command match { val worldSize = readBuffer.getInt()
val jobId = readBuffer.getString
val command = readBuffer.getString
val trackerCommand = command match {
case "start" => WorkerStart(rank, worldSize, jobId) case "start" => WorkerStart(rank, worldSize, jobId)
case "shutdown" => case "shutdown" =>
transient = true transient = true
@ -99,8 +102,11 @@ private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: A
WorkerRecover(rank, worldSize, jobId) WorkerRecover(rank, worldSize, jobId)
case "print" => case "print" =>
transient = true transient = true
WorkerTrackerPrint(rank, worldSize, jobId, buffer.getString) WorkerTrackerPrint(rank, worldSize, jobId, readBuffer.getString)
} }
stashSpillOver(readBuffer)
trackerCommand
} }
startWith(AwaitingHandshake, DataStruct()) startWith(AwaitingHandshake, DataStruct())
@ -120,9 +126,14 @@ private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: A
case Event(Tcp.Received(bytes), validator) => case Event(Tcp.Received(bytes), validator) =>
bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) } bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
if (validator.verify(readBuffer)) { if (validator.verify(readBuffer)) {
readBuffer.flip() Try(decodeCommand(readBuffer)) match {
tracker ! decodeCommand(readBuffer) case scala.util.Success(decodedCommand) =>
stashSpillOver(readBuffer) tracker ! decodedCommand
case scala.util.Failure(th: java.nio.BufferOverflowException) =>
// BufferOverflowException would occur if the message to print has not arrived yet.
// Do nothing, wait for next Tcp.Received event
case scala.util.Failure(th: Throwable) => throw th
}
} }
stay stay

View File

@ -172,6 +172,32 @@ class RabitTrackerConnectionHandlerTest
trackerProbe.expectMsg(printCmd) trackerProbe.expectMsg(printCmd)
} }
it should "handle fragmented print command without throwing exception" in {
val trackerProbe = TestProbe()
val connProbe = TestProbe()
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
trackerProbe.ref, connProbe.ref))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
fsm ! Tcp.Received(magic)
connProbe.expectMsg(Tcp.Write(magic))
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
// ResumeReading should be seen once state transitions
connProbe.expectMsg(Tcp.ResumeReading)
val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!")
// 4 + 4 + 4 + 5 = 17
val (partialMessage, remainder) = printCmd.encode.splitAt(17)
fsm ! Tcp.Received(partialMessage)
fsm ! Tcp.Received(remainder)
trackerProbe.expectMsg(printCmd)
}
it should "handle spill-over Tcp data correctly between state transition" in { it should "handle spill-over Tcp data correctly between state transition" in {
val trackerProbe = TestProbe() val trackerProbe = TestProbe()
val connProbe = TestProbe() val connProbe = TestProbe()