fix bug: doing rabit call after finalize in spark prediction phase (#1420)

This commit is contained in:
Fangzhou 2016-07-29 12:11:20 +08:00 committed by Yuan (Terry) Tang
parent 328e8e4c69
commit a8adf16228

View File

@@ -20,8 +20,9 @@ import org.apache.hadoop.fs.{Path, FileSystem}
 import org.apache.spark.{TaskContext, SparkContext}
 import org.apache.spark.mllib.linalg.{DenseVector, Vector}
 import org.apache.spark.rdd.RDD
-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
-import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
+import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
+import scala.collection.JavaConverters._

 class XGBoostModel(_booster: Booster) extends Serializable {
@@ -37,6 +38,8 @@ class XGBoostModel(_booster: Booster) extends Serializable {
     val appName = testSet.context.appName
     testSet.mapPartitions { testSamples =>
       if (testSamples.hasNext) {
+        val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+        Rabit.init(rabitEnv.asJava)
         val cacheFileName = {
           if (useExternalCache) {
             s"$appName-dtest_cache-${TaskContext.getPartitionId()}"
@@ -45,7 +48,9 @@ class XGBoostModel(_booster: Booster) extends Serializable {
           }
         }
         val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
-        Iterator(broadcastBooster.value.predict(dMatrix))
+        val res = broadcastBooster.value.predict(dMatrix)
+        Rabit.shutdown()
+        Iterator(res)
       } else {
         Iterator()
       }