From a187ed6c8f3aa40b47d5be80667cbbe6a6fd563d Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Sun, 7 Jan 2018 08:34:18 -0800 Subject: [PATCH] [jvm-packages] tiny fix for empty partition in predict (#3014) * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * tiny fix for empty partition in predict * further fix --- .../scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index 4b77eec4b..8a0d6d2e6 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -169,12 +169,12 @@ abstract class XGBoostModel(protected var _booster: Booster) def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = { val broadcastBooster = testSet.sparkContext.broadcast(_booster) testSet.mapPartitions { testSamples => - val sampleArray = testSamples.toList - val numRows = sampleArray.size - val numColumns = sampleArray.head.size + val sampleArray = testSamples.toArray + val numRows = sampleArray.length if (numRows == 0) { Iterator() } else { + val numColumns = sampleArray.head.size val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) Rabit.init(rabitEnv.asJava) // translate to required format