[jvm-packages] Fixed java.nio.BufferUnderFlow issue in Scala Rabit tracker. (#1993)
* [jvm-packages] Scala implementation of the Rabit tracker. A Scala implementation of RabitTracker that is interface-interchangable with the Java implementation, ported from `tracker.py` in the [dmlc-core project](https://github.com/dmlc/dmlc-core). * [jvm-packages] Updated Akka dependency in pom.xml. * Refactored the RabitTracker directory structure. * Fixed premature stopping of connection handler. Added a new finite state "AwaitingPortNumber" to explicitly wait for the worker to send the port, and close the connection. Stopping the actor prematurely sends a TCP RST to the worker, causing the worker to crash on AssertionError. * Added interface IRabitTracker so that user can switch implementations. * Default timeout duration changes. * Dependency for Akka tests. * Removed the main function of RabitTracker. * A skeleton for testing Akka-based Rabit tracker. * waitFor() in RabitTracker no longer throws exceptions. * Completed unit test for the 'start' command of Rabit tracker. * Preliminary support for Rabit Allreduce via JNI (no prepare function support yet.) * Fixed the default timeout duration. * Use Java container to avoid serialization issues due to intermediate wrappers. * Added tests for Allreduce/model training using Scala Rabit tracker. * Added spill-over unit test for the Scala Rabit tracker. * Fixed a typo. * Overhaul of RabitTracker interface per code review. - Removed methods start() waitFor() (no arguments) from IRabitTracker. - The timeout in start(timeout) is now worker connection timeout, as tcp socket binding timeout is less intuitive. - Dropped time unit from start(...) and waitFor(...) methods; the default time unit is millisecond. - Moved random port number generation into the RabitTrackerHandler. - Moved all Rabit-related classes to package ml.dmlc.xgboost4j.scala.rabit. * More code refactoring and comments. * Unified timeout constants. Readable tracker status code. * Add comments to indicate that allReduce is for tests only. Removed all other variants. * Removed unused imports. * Simplified signatures of training methods. - Moved TrackerConf into parameter map. - Changed GeneralParams so that TrackerConf becomes a standalone parameter. - Updated test cases accordingly. * Changed monitoring strategies. * Reverted monitoring changes. * Update test case for Rabit AllReduce. * Mix in UncaughtExceptionHandler into IRabitTracker to prevent tracker from hanging due to exceptions thrown by workers. * More comprehensive test cases for exception handling and worker connection timeout. * Handle executor loss due to unknown cause: the newly spawned executor will attempt to connect to the tracker. Interrupt tracker in such case. * Per code-review, removed training timeout from TrackerConf. Timeout logic must be implemented explicitly and externally in the driver code. * Reverted scalastyle-config changes. * Visibility scope change. Interface tweaks. * Use match pattern to handle tracker_conf parameter. * Minor clarification in JNI code. * Clearer intent in match pattern to suppress warnings. * Removed Future from constructor. Block in start() and waitFor() instead. * Revert inadvertent comment changes. * Removed debugging information. * Updated test cases that are a bit finicky. * Added comments on the reasoning behind the unit tests for testing Rabit tracker robustness. * Fixed BufferUnderFlow bug in decoding tracker 'print' command. * Merge conflicts resolution.
This commit is contained in:
parent
2250b9b6d2
commit
4fb7fdb240
@ -84,12 +84,15 @@ private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: A
|
|||||||
def getNeighboringWorkers: Set[Int] = neighboringWorkers
|
def getNeighboringWorkers: Set[Int] = neighboringWorkers
|
||||||
|
|
||||||
def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
|
def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
|
||||||
val rank = buffer.getInt()
|
val readBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
|
||||||
val worldSize = buffer.getInt()
|
readBuffer.flip()
|
||||||
val jobId = buffer.getString
|
|
||||||
|
|
||||||
val command = buffer.getString
|
val rank = readBuffer.getInt()
|
||||||
command match {
|
val worldSize = readBuffer.getInt()
|
||||||
|
val jobId = readBuffer.getString
|
||||||
|
|
||||||
|
val command = readBuffer.getString
|
||||||
|
val trackerCommand = command match {
|
||||||
case "start" => WorkerStart(rank, worldSize, jobId)
|
case "start" => WorkerStart(rank, worldSize, jobId)
|
||||||
case "shutdown" =>
|
case "shutdown" =>
|
||||||
transient = true
|
transient = true
|
||||||
@ -99,8 +102,11 @@ private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: A
|
|||||||
WorkerRecover(rank, worldSize, jobId)
|
WorkerRecover(rank, worldSize, jobId)
|
||||||
case "print" =>
|
case "print" =>
|
||||||
transient = true
|
transient = true
|
||||||
WorkerTrackerPrint(rank, worldSize, jobId, buffer.getString)
|
WorkerTrackerPrint(rank, worldSize, jobId, readBuffer.getString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stashSpillOver(readBuffer)
|
||||||
|
trackerCommand
|
||||||
}
|
}
|
||||||
|
|
||||||
startWith(AwaitingHandshake, DataStruct())
|
startWith(AwaitingHandshake, DataStruct())
|
||||||
@ -120,9 +126,14 @@ private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: A
|
|||||||
case Event(Tcp.Received(bytes), validator) =>
|
case Event(Tcp.Received(bytes), validator) =>
|
||||||
bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
|
bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
|
||||||
if (validator.verify(readBuffer)) {
|
if (validator.verify(readBuffer)) {
|
||||||
readBuffer.flip()
|
Try(decodeCommand(readBuffer)) match {
|
||||||
tracker ! decodeCommand(readBuffer)
|
case scala.util.Success(decodedCommand) =>
|
||||||
stashSpillOver(readBuffer)
|
tracker ! decodedCommand
|
||||||
|
case scala.util.Failure(th: java.nio.BufferOverflowException) =>
|
||||||
|
// BufferOverflowException would occur if the message to print has not arrived yet.
|
||||||
|
// Do nothing, wait for next Tcp.Received event
|
||||||
|
case scala.util.Failure(th: Throwable) => throw th
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stay
|
stay
|
||||||
|
|||||||
@ -172,6 +172,32 @@ class RabitTrackerConnectionHandlerTest
|
|||||||
trackerProbe.expectMsg(printCmd)
|
trackerProbe.expectMsg(printCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
it should "handle fragmented print command without throwing exception" in {
|
||||||
|
val trackerProbe = TestProbe()
|
||||||
|
val connProbe = TestProbe()
|
||||||
|
|
||||||
|
val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
|
||||||
|
trackerProbe.ref, connProbe.ref))
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
|
||||||
|
|
||||||
|
fsm ! Tcp.Received(magic)
|
||||||
|
connProbe.expectMsg(Tcp.Write(magic))
|
||||||
|
|
||||||
|
fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
|
||||||
|
fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
|
||||||
|
// ResumeReading should be seen once state transitions
|
||||||
|
connProbe.expectMsg(Tcp.ResumeReading)
|
||||||
|
|
||||||
|
val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!")
|
||||||
|
// 4 + 4 + 4 + 5 = 17
|
||||||
|
val (partialMessage, remainder) = printCmd.encode.splitAt(17)
|
||||||
|
|
||||||
|
fsm ! Tcp.Received(partialMessage)
|
||||||
|
fsm ! Tcp.Received(remainder)
|
||||||
|
|
||||||
|
trackerProbe.expectMsg(printCmd)
|
||||||
|
}
|
||||||
|
|
||||||
it should "handle spill-over Tcp data correctly between state transition" in {
|
it should "handle spill-over Tcp data correctly between state transition" in {
|
||||||
val trackerProbe = TestProbe()
|
val trackerProbe = TestProbe()
|
||||||
val connProbe = TestProbe()
|
val connProbe = TestProbe()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user