[jvm-packages] fix the scalability issue of prediction (#4033)

Nan Zhu, 2018-12-29 20:46:30 -08:00 (committed by GitHub)
parent 15fe2f1e7c
commit f368d0de2b
3 changed files with 30 additions and 34 deletions
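
The fix targets the predict path of both XGBoostClassificationModel and XGBoostRegressionModel. Previously, `rowIterator.duplicate` split each partition's row iterator into two copies: one was drained immediately (via `.toList`) to build the `DMatrix` for batch prediction, the other was replayed afterwards to attach the predictions to the original rows. Scala backs the two copies of a duplicated iterator with a shared buffer, so every row of the partition stayed in memory until the lagging copy caught up. The commit drops `duplicate` entirely: predictions are computed in one RDD and the untouched input rows are paired back in with `zipPartitions`. A standalone sketch of the buffering behavior (toy code, not from this commit):

    // Iterator.duplicate backs both copies with a shared queue, so every
    // element consumed by one copy before the other is buffered.
    object DuplicateBuffering extends App {
      val (itr1, itr2) = Iterator.range(0, 10000).duplicate
      // Draining itr2 first (as the old predict path drained the feature
      // copy to build the DMatrix) pushes all 10000 elements into the
      // buffer so that itr1 can still replay them later.
      val total = itr2.sum     // itr1 has consumed nothing yet
      val first = itr1.next()  // served from the shared buffer
      println(s"total=$total first=$first")
    }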


@@ -335,25 +335,6 @@
           </execution>
         </executions>
       </plugin>
-      <plugin>
-        <groupId>org.jacoco</groupId>
-        <artifactId>jacoco-maven-plugin</artifactId>
-        <version>0.7.9</version>
-        <executions>
-          <execution>
-            <goals>
-              <goal>prepare-agent</goal>
-            </goals>
-          </execution>
-          <execution>
-            <id>report</id>
-            <phase>test</phase>
-            <goals>
-              <goal>report</goal>
-            </goals>
-          </execution>
-        </executions>
-      </plugin>
     </plugins>
   </build>
   <dependencies>


@@ -285,12 +285,12 @@ class XGBoostClassificationModel private[ml](
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
 
-    val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
+    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
       if (rowIterator.hasNext) {
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
-        val (rowItr1, rowItr2) = rowIterator.duplicate
-        val featuresIterator = rowItr2.map(row => row.getAs[Vector](
+        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
           $(featuresCol))).toList.iterator
         import DataUtils._
         val cacheInfo = {
@@ -307,19 +307,27 @@ class XGBoostClassificationModel private[ml](
           val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
             producePredictionItrs(bBooster, dm)
           Rabit.shutdown()
-          produceResultIterator(rowItr1, rawPredictionItr, probabilityItr, predLeafItr,
+          Iterator(rawPredictionItr, probabilityItr, predLeafItr,
             predContribItr)
         } finally {
           dm.delete()
         }
       } else {
-        Iterator[Row]()
+        Iterator()
       }
     }
 
-
+    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
+      case (inputIterator, predictionItr) =>
+        if (inputIterator.hasNext) {
+          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
+            predictionItr.next(), predictionItr.next())
+        } else {
+          Iterator()
+        }
+    }
     bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
+    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }
 
   private def produceResultIterator(

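Reduced to a toy example, the new predict path works like this: per partition, the prediction RDD emits a single element carrying the prediction iterators, and `zipPartitions` pairs that element with the partition's original rows, which flow straight through. The sketch below is a simplified illustration with invented names (`ZipPartitionsSketch`, one prediction stream instead of the classifier's four), not the library code:

    import org.apache.spark.sql.SparkSession

    object ZipPartitionsSketch extends App {
      val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
      val inputRDD = spark.sparkContext.parallelize((1 to 8).map(_.toDouble), numSlices = 2)

      // Per partition, run "inference" once over all rows and emit a single
      // element wrapping the resulting prediction iterator (the real code
      // packs several iterators into one Iterator(...) element).
      val predictionRDD = inputRDD.mapPartitions { rows =>
        if (rows.hasNext) Iterator(rows.map(_ * 10).toList.iterator)
        else Iterator()
      }

      // Both RDDs share a lineage, hence a partition count, which
      // zipPartitions requires; no duplicate() buffer is involved.
      val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
        case (rows, predItrs) =>
          if (rows.hasNext) rows.zip(predItrs.next())
          else Iterator()
      }

      resultRDD.collect().foreach(println)  // (1.0,10.0), (2.0,20.0), ...
      spark.stop()
    }

One consequence of the pattern: the source dataset is referenced by both sides of the zip, so each partition is computed twice unless it is cached; the change trades some recomputation for bounded memory.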

@@ -257,13 +257,12 @@ class XGBoostRegressionModel private[ml] (
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
 
-
-    val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
+    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
       if (rowIterator.hasNext) {
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
-        val (rowItr1, rowItr2) = rowIterator.duplicate
-        val featuresIterator = rowItr2.map(row => row.getAs[Vector](
+        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
           $(featuresCol))).toList.iterator
         import DataUtils._
         val cacheInfo = {
@@ -273,7 +272,6 @@ class XGBoostRegressionModel private[ml] (
             null
           }
         }
-
         val dm = new DMatrix(
           XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
           cacheInfo)
@@ -281,16 +279,25 @@ class XGBoostRegressionModel private[ml] (
           val Array(originalPredictionItr, predLeafItr, predContribItr) =
             producePredictionItrs(bBooster, dm)
           Rabit.shutdown()
-          produceResultIterator(rowItr1, originalPredictionItr, predLeafItr, predContribItr)
+          Iterator(originalPredictionItr, predLeafItr, predContribItr)
         } finally {
           dm.delete()
         }
       } else {
-        Iterator[Row]()
+        Iterator()
      }
     }
+
+    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
+      case (inputIterator, predictionItr) =>
+        if (inputIterator.hasNext) {
+          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
+            predictionItr.next())
+        } else {
+          Iterator()
+        }
+    }
     bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
+    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }
 
   private def produceResultIterator(
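
What `produceResultIterator` then does with the handed-over iterators is essentially a lock-step zip of the original rows with the prediction streams. A hypothetical, reduced version for the regressor's three streams (the name `appendPredictions` and the `Array[Float]` payloads are illustrative assumptions, and a single-output model is assumed so the scalar prediction is element 0):

    import org.apache.spark.sql.Row

    // Hypothetical reduced helper: advance the row iterator and the three
    // prediction iterators in lockstep, widening each original row with
    // its prediction columns.
    def appendPredictions(
        rows: Iterator[Row],
        predictionItr: Iterator[Array[Float]],
        predLeafItr: Iterator[Array[Float]],
        predContribItr: Iterator[Array[Float]]): Iterator[Row] =
      rows.zip(predictionItr).zip(predLeafItr.zip(predContribItr)).map {
        case ((row, pred), (leaf, contrib)) =>
          Row.fromSeq(row.toSeq ++ Seq(
            pred(0).toDouble,  // assumed single-output prediction
            leaf.map(_.toDouble),
            contrib.map(_.toDouble)))
      }

Note that the order of the `predictionItr.next()` calls inside `zipPartitions` must match the order in which the iterators were packed into the `Iterator(...)` element, which is why both files list them identically on each side of the change.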