diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index 9aa5e84dc..6dad15bc4 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -20,8 +20,9 @@ import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.spark.{TaskContext, SparkContext} import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.rdd.RDD -import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} -import ml.dmlc.xgboost4j.scala.{DMatrix, Booster} +import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix} +import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} +import scala.collection.JavaConverters._ class XGBoostModel(_booster: Booster) extends Serializable { @@ -37,6 +38,8 @@ class XGBoostModel(_booster: Booster) extends Serializable { val appName = testSet.context.appName testSet.mapPartitions { testSamples => if (testSamples.hasNext) { + val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap + Rabit.init(rabitEnv.asJava) val cacheFileName = { if (useExternalCache) { s"$appName-dtest_cache-${TaskContext.getPartitionId()}" @@ -45,7 +48,9 @@ class XGBoostModel(_booster: Booster) extends Serializable { } } val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName)) - Iterator(broadcastBooster.value.predict(dMatrix)) + val res = broadcastBooster.value.predict(dMatrix) + Rabit.shutdown() + Iterator(res) } else { Iterator() }