[jvm-packages] fix bug doing rabit call after finalize (#2079)

[jvm-packages]fix bug doing rabit call after finalize
This commit is contained in:
hlsc 2017-03-03 08:46:57 +08:00 committed by Nan Zhu
parent fd19b7a188
commit a92093388d

View File

@ -57,9 +57,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
if (testSamples.hasNext) {
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
Iterator(broadcastBooster.value.predictLeaf(dMatrix))
val res = broadcastBooster.value.predictLeaf(dMatrix)
Rabit.shutdown()
Iterator(res)
} else {
Iterator()
}
@ -145,8 +149,9 @@ abstract class XGBoostModel(protected var _booster: Booster)
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
}
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
val res = broadcastBooster.value.predict(dMatrix)
Rabit.shutdown()
Iterator(broadcastBooster.value.predict(dMatrix))
Iterator(res)
}
}
}