diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index 8f934229c..650eef2c8 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -57,9 +57,13 @@ abstract class XGBoostModel(protected var _booster: Booster) import DataUtils._ val broadcastBooster = testSet.sparkContext.broadcast(_booster) testSet.mapPartitions { testSamples => + val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) + Rabit.init(rabitEnv.asJava) if (testSamples.hasNext) { val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) - Iterator(broadcastBooster.value.predictLeaf(dMatrix)) + val res = broadcastBooster.value.predictLeaf(dMatrix) + Rabit.shutdown() + Iterator(res) } else { Iterator() } @@ -145,8 +149,9 @@ abstract class XGBoostModel(protected var _booster: Booster) flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat } val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue) + val res = broadcastBooster.value.predict(dMatrix) Rabit.shutdown() - Iterator(broadcastBooster.value.predict(dMatrix)) + Iterator(res) } } }