From 5b54b9437cf2f329a6aedcab9377faf769472f57 Mon Sep 17 00:00:00 2001 From: Xin Yin Date: Sun, 5 Mar 2017 15:40:59 -0600 Subject: [PATCH] Fixed Exception handling for fragmented Rabit 'print' tracker command. Fixed unit test. (#2081) --- .../scala/rabit/handler/RabitWorkerHandler.scala | 4 ++-- .../rabit/RabitTrackerConnectionHandlerTest.scala | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala index 963799884..234c4d25a 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala @@ -129,8 +129,8 @@ private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: A Try(decodeCommand(readBuffer)) match { case scala.util.Success(decodedCommand) => tracker ! decodedCommand - case scala.util.Failure(th: java.nio.BufferOverflowException) => - // BufferOverflowException would occur if the message to print has not arrived yet. + case scala.util.Failure(th: java.nio.BufferUnderflowException) => + // BufferUnderflowException would occur if the message to print has not arrived yet. // Do nothing, wait for next Tcp.Received event case scala.util.Failure(th: Throwable) => throw th } diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala index 42994baca..cd9016812 100644 --- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala +++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala @@ -188,9 +188,14 @@ class RabitTrackerConnectionHandlerTest // ResumeReading should be seen once state transitions connProbe.expectMsg(Tcp.ResumeReading) - val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!") - // 4 + 4 + 4 + 5 = 17 - val (partialMessage, remainder) = printCmd.encode.splitAt(17) + val printCmd = WorkerTrackerPrint(0, 4, "0", "fragmented!") + // 4 (rank: Int) + 4 (worldSize: Int) + (4+1) (jobId: String) + (4+5) (command: String) = 22 + val (partialMessage, remainder) = printCmd.encode.splitAt(22) + + // make sure that the partialMessage in itself is a valid command + val partialMsgBuf = ByteBuffer.allocate(22).order(ByteOrder.nativeOrder()) + partialMsgBuf.put(partialMessage.asByteBuffer) + RabitWorkerHandler.StructTrackerCommand.verify(partialMsgBuf) shouldBe true fsm ! Tcp.Received(partialMessage) fsm ! Tcp.Received(remainder)