Fix bug: Rabit call made after finalize in the Spark prediction phase (#1420)
This commit is contained in:
parent
328e8e4c69
commit
a8adf16228
@ -20,8 +20,9 @@ import org.apache.hadoop.fs.{Path, FileSystem}
|
||||
import org.apache.spark.{TaskContext, SparkContext}
|
||||
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
class XGBoostModel(_booster: Booster) extends Serializable {
|
||||
|
||||
@ -37,6 +38,8 @@ class XGBoostModel(_booster: Booster) extends Serializable {
|
||||
val appName = testSet.context.appName
|
||||
testSet.mapPartitions { testSamples =>
|
||||
if (testSamples.hasNext) {
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val cacheFileName = {
|
||||
if (useExternalCache) {
|
||||
s"$appName-dtest_cache-${TaskContext.getPartitionId()}"
|
||||
@ -45,7 +48,9 @@ class XGBoostModel(_booster: Booster) extends Serializable {
|
||||
}
|
||||
}
|
||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
|
||||
Iterator(broadcastBooster.value.predict(dMatrix))
|
||||
val res = broadcastBooster.value.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(res)
|
||||
} else {
|
||||
Iterator()
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user