fix bug: doing rabit call after finalize in spark prediction phase (#1420)
This commit is contained in:
parent
328e8e4c69
commit
a8adf16228
@ -20,8 +20,9 @@ import org.apache.hadoop.fs.{Path, FileSystem}
|
|||||||
import org.apache.spark.{TaskContext, SparkContext}
|
import org.apache.spark.{TaskContext, SparkContext}
|
||||||
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
|
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||||
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
class XGBoostModel(_booster: Booster) extends Serializable {
|
class XGBoostModel(_booster: Booster) extends Serializable {
|
||||||
|
|
||||||
@ -37,6 +38,8 @@ class XGBoostModel(_booster: Booster) extends Serializable {
|
|||||||
val appName = testSet.context.appName
|
val appName = testSet.context.appName
|
||||||
testSet.mapPartitions { testSamples =>
|
testSet.mapPartitions { testSamples =>
|
||||||
if (testSamples.hasNext) {
|
if (testSamples.hasNext) {
|
||||||
|
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||||
|
Rabit.init(rabitEnv.asJava)
|
||||||
val cacheFileName = {
|
val cacheFileName = {
|
||||||
if (useExternalCache) {
|
if (useExternalCache) {
|
||||||
s"$appName-dtest_cache-${TaskContext.getPartitionId()}"
|
s"$appName-dtest_cache-${TaskContext.getPartitionId()}"
|
||||||
@ -45,7 +48,9 @@ class XGBoostModel(_booster: Booster) extends Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
|
val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
|
||||||
Iterator(broadcastBooster.value.predict(dMatrix))
|
val res = broadcastBooster.value.predict(dMatrix)
|
||||||
|
Rabit.shutdown()
|
||||||
|
Iterator(res)
|
||||||
} else {
|
} else {
|
||||||
Iterator()
|
Iterator()
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user