[jvm-packages] fix the scalability issue of prediction (#4033)
commit f368d0de2b
parent 15fe2f1e7c
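This change fixes prediction in the Spark wrappers, which previously could not scale to large partitions. Each task split its row iterator with Iterator.duplicate, drained one half to build the prediction DMatrix, and kept the other half to stitch predictions back onto the input rows; since Scala's duplicate buffers every element consumed by one half but not yet by the other, the whole partition ended up held in memory. The fix removes the duplicate: the mapPartitions pass now emits only packed prediction iterators, and a second zipPartitions pass streams the original input rows past them to assemble the result rows. The commit also drops the jacoco-maven-plugin coverage configuration from the pom.

A minimal sketch of the buffering behaviour being avoided (plain Scala, not from this patch):

    // Iterator.duplicate buffers the gap between its two halves, so draining
    // one half before touching the other buffers the entire sequence.
    val (rows1, rows2) = Iterator.range(0, 10000000).duplicate
    val n = rows2.size  // drains rows2; all ten million elements are now held
                        // in memory so that rows1 can still replay them
    println(n + " " + rows1.size)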
pom.xml
@@ -335,25 +335,6 @@
           </execution>
         </executions>
       </plugin>
-      <plugin>
-        <groupId>org.jacoco</groupId>
-        <artifactId>jacoco-maven-plugin</artifactId>
-        <version>0.7.9</version>
-        <executions>
-          <execution>
-            <goals>
-              <goal>prepare-agent</goal>
-            </goals>
-          </execution>
-          <execution>
-            <id>report</id>
-            <phase>test</phase>
-            <goals>
-              <goal>report</goal>
-            </goals>
-          </execution>
-        </executions>
-      </plugin>
     </plugins>
   </build>
   <dependencies>
XGBoostClassificationModel.scala
@@ -285,12 +285,12 @@ class XGBoostClassificationModel private[ml](
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName

-    val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
+    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
       if (rowIterator.hasNext) {
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
-        val (rowItr1, rowItr2) = rowIterator.duplicate
-        val featuresIterator = rowItr2.map(row => row.getAs[Vector](
+        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
           $(featuresCol))).toList.iterator
         import DataUtils._
         val cacheInfo = {
@@ -307,19 +307,27 @@ class XGBoostClassificationModel private[ml](
           val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
             producePredictionItrs(bBooster, dm)
           Rabit.shutdown()
-          produceResultIterator(rowItr1, rawPredictionItr, probabilityItr, predLeafItr,
+          Iterator(rawPredictionItr, probabilityItr, predLeafItr,
             predContribItr)
         } finally {
           dm.delete()
         }
       } else {
-        Iterator[Row]()
+        Iterator()
       }
     }
+    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
+      case (inputIterator, predictionItr) =>
+        if (inputIterator.hasNext) {
+          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
+            predictionItr.next(), predictionItr.next())
+        } else {
+          Iterator()
+        }
+    }

     bBooster.unpersist(blocking = false)
-
-    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
+    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }

   private def produceResultIterator(
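The new flow, sketched below with hypothetical names (predictPartition stands in for the DMatrix construction and producePredictionItrs call above; an illustration, not the patch's code): the expensive pass emits at most one element per partition, its packed prediction iterator, and zipPartitions then walks a freshly computed copy of the input partition in lockstep with it. Both parents of the zipped RDD are narrow dependencies on the same dataset, so the input partition is recomputed rather than buffered, and preservesPartitioning = true keeps any existing partitioner intact.

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.Row

    // Sketch of the zipPartitions pattern used above; `predictPartition` is a
    // hypothetical stand-in for building the DMatrix and running the booster.
    def zipWithPredictions(
        input: RDD[Row],
        predictPartition: Iterator[Row] => Iterator[Float]): RDD[(Row, Float)] = {
      // One element per non-empty partition: that partition's predictions.
      val predictions: RDD[Iterator[Float]] = input.mapPartitions { rows =>
        if (rows.hasNext) Iterator(predictPartition(rows)) else Iterator()
      }
      // Stream the recomputed input rows past the prediction iterator; no row
      // is buffered while waiting to be matched with its prediction.
      input.zipPartitions(predictions, preservesPartitioning = true) { (rows, predItr) =>
        if (rows.hasNext) rows.zip(predItr.next()) else Iterator()
      }
    }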
XGBoostRegressionModel.scala
@@ -257,13 +257,12 @@ class XGBoostRegressionModel private[ml] (

     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
-
-    val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
+    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
       if (rowIterator.hasNext) {
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
-        val (rowItr1, rowItr2) = rowIterator.duplicate
-        val featuresIterator = rowItr2.map(row => row.getAs[Vector](
+        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
           $(featuresCol))).toList.iterator
         import DataUtils._
         val cacheInfo = {
@@ -273,7 +272,6 @@ class XGBoostRegressionModel private[ml] (
             null
           }
         }
-
         val dm = new DMatrix(
           XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
           cacheInfo)
@@ -281,16 +279,25 @@ class XGBoostRegressionModel private[ml] (
           val Array(originalPredictionItr, predLeafItr, predContribItr) =
             producePredictionItrs(bBooster, dm)
           Rabit.shutdown()
-          produceResultIterator(rowItr1, originalPredictionItr, predLeafItr, predContribItr)
+          Iterator(originalPredictionItr, predLeafItr, predContribItr)
         } finally {
           dm.delete()
         }
       } else {
-        Iterator[Row]()
+        Iterator()
       }
     }
+    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
+      case (inputIterator, predictionItr) =>
+        if (inputIterator.hasNext) {
+          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
+            predictionItr.next())
+        } else {
+          Iterator()
+        }
+    }
     bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
+    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }

   private def produceResultIterator(
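The regression model receives the identical restructuring, packing three prediction iterators (predictions, leaf indices, contributions) where the classifier packs four. Two details are worth noting: zipPartitions requires both RDDs to have the same number of partitions, which holds because inputRDD and predictionRDD derive from the same dataset, and the empty-partition branch changes from Iterator[Row]() to Iterator() because the mapPartitions pass no longer yields finished Rows but packed prediction iterators. A hypothetical illustration of the packing convention, one element per partition, unpacked by predictionItr.next() in packing order:

    // Pack per-kind prediction iterators into a single element, then unpack
    // them in the same fixed order on the consumer side.
    val packed = Iterator(Iterator(0.3f, 0.7f), Iterator(11f, 42f)) // preds, leaves
    val preds  = packed.next()  // first packed iterator
    val leaves = packed.next()  // second packed iterator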