From a92093388de21b1156e0c53a31bf4e92c9120848 Mon Sep 17 00:00:00 2001 From: hlsc Date: Fri, 3 Mar 2017 08:46:57 +0800 Subject: [PATCH] [jvm-packages] fix bug doing rabit call after finalize (#2079) [jvm-packages]fix bug doing rabit call after finalize --- .../ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index 8f934229c..650eef2c8 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -57,9 +57,13 @@ abstract class XGBoostModel(protected var _booster: Booster) import DataUtils._ val broadcastBooster = testSet.sparkContext.broadcast(_booster) testSet.mapPartitions { testSamples => + val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) + Rabit.init(rabitEnv.asJava) if (testSamples.hasNext) { val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) - Iterator(broadcastBooster.value.predictLeaf(dMatrix)) + val res = broadcastBooster.value.predictLeaf(dMatrix) + Rabit.shutdown() + Iterator(res) } else { Iterator() } @@ -145,8 +149,9 @@ abstract class XGBoostModel(protected var _booster: Booster) flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat } val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue) + val res = broadcastBooster.value.predict(dMatrix) Rabit.shutdown() - Iterator(broadcastBooster.value.predict(dMatrix)) + Iterator(res) } } }