[jvm-packages] fix the scalability issue of prediction (#4033)

Nan Zhu 2018-12-29 20:46:30 -08:00 committed by GitHub
parent 15fe2f1e7c
commit f368d0de2b
3 changed files with 30 additions and 34 deletions
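The substance of the fix, as read from the two Scala diffs below: the old prediction path called rowIterator.duplicate on every partition and fully drained one copy (via .toList) to build the DMatrix before the second copy was read to assemble result rows. Scala's Iterator.duplicate backs both cursors with a shared buffer, so consuming one copy far ahead of the other keeps the entire partition in memory, and prediction memory therefore grew with partition size instead of staying bounded. The snippet below is a minimal plain-Scala illustration of that buffering behavior; it is not part of the patch.

object DuplicateBufferingDemo {
  def main(args: Array[String]): Unit = {
    // Stand-in for one partition's row iterator.
    val rows = Iterator.range(0, 1000000)
    val (rowItr1, rowItr2) = rows.duplicate
    // Draining rowItr2 first, as the old code did, forces duplicate to
    // buffer every element so that rowItr1 can still replay it later.
    val features = rowItr2.toList
    println(s"buffered ${features.size} elements before rowItr1 yielded any")
    println(rowItr1.take(3).toList) // served from the internal buffer, not streamed
  }
}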

pom.xml

@@ -335,25 +335,6 @@
         </execution>
       </executions>
     </plugin>
-    <plugin>
-      <groupId>org.jacoco</groupId>
-      <artifactId>jacoco-maven-plugin</artifactId>
-      <version>0.7.9</version>
-      <executions>
-        <execution>
-          <goals>
-            <goal>prepare-agent</goal>
-          </goals>
-        </execution>
-        <execution>
-          <id>report</id>
-          <phase>test</phase>
-          <goals>
-            <goal>report</goal>
-          </goals>
-        </execution>
-      </executions>
-    </plugin>
   </plugins>
 </build>
 <dependencies>

XGBoostClassificationModel.scala

@@ -285,12 +285,12 @@ class XGBoostClassificationModel private[ml](
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName

-    val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
+    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
       if (rowIterator.hasNext) {
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
-        val (rowItr1, rowItr2) = rowIterator.duplicate
-        val featuresIterator = rowItr2.map(row => row.getAs[Vector](
+        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
           $(featuresCol))).toList.iterator
         import DataUtils._
         val cacheInfo = {
@@ -307,19 +307,27 @@ class XGBoostClassificationModel private[ml](
           val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
             producePredictionItrs(bBooster, dm)
           Rabit.shutdown()
-          produceResultIterator(rowItr1, rawPredictionItr, probabilityItr, predLeafItr,
+          Iterator(rawPredictionItr, probabilityItr, predLeafItr,
             predContribItr)
         } finally {
           dm.delete()
         }
       } else {
-        Iterator[Row]()
+        Iterator()
       }
     }
+    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
+      case (inputIterator, predictionItr) =>
+        if (inputIterator.hasNext) {
+          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
+            predictionItr.next(), predictionItr.next())
+        } else {
+          Iterator()
+        }
+    }
     bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
+    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }

   private def produceResultIterator(
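Both model classes apply the same restructuring, so the regression diff below mirrors the classification one: predictions are computed into a second RDD whose partitions each carry a single element (the packed prediction iterators), and zipPartitions then stitches the untouched input rows back to their predictions partition-by-partition, with no iterator duplication. The sketch below shows the pattern in isolation; the names and the toy arithmetic standing in for a booster are hypothetical.

import org.apache.spark.sql.SparkSession

object ZipPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    val input = spark.sparkContext.parallelize(Seq(1.0, 2.0, 3.0, 4.0), numSlices = 2)

    // One element per non-empty partition: an iterator over that partition's
    // predictions, mirroring the patch's Iterator(rawPredictionItr, ...) output.
    val predictions = input.mapPartitions { rows =>
      if (rows.hasNext) Iterator(rows.map(_ * 10).toArray.iterator) else Iterator()
    }

    // Narrow (shuffle-free) partition-wise join of rows with their predictions.
    val results = input.zipPartitions(predictions, preservesPartitioning = true) {
      (inputItr, predItr) =>
        if (inputItr.hasNext) inputItr.zip(predItr.next()) else Iterator()
    }

    results.collect().foreach(println) // (1.0,10.0), (2.0,20.0), ...
    spark.stop()
  }
}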

XGBoostRegressionModel.scala

@@ -257,13 +257,12 @@ class XGBoostRegressionModel private[ml] (
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName

-    val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
+    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
       if (rowIterator.hasNext) {
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
-        val (rowItr1, rowItr2) = rowIterator.duplicate
-        val featuresIterator = rowItr2.map(row => row.getAs[Vector](
+        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
           $(featuresCol))).toList.iterator
         import DataUtils._
         val cacheInfo = {
@@ -273,7 +272,6 @@ class XGBoostRegressionModel private[ml] (
             null
           }
         }
-
         val dm = new DMatrix(
           XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
           cacheInfo)
@@ -281,16 +279,25 @@ class XGBoostRegressionModel private[ml] (
           val Array(originalPredictionItr, predLeafItr, predContribItr) =
             producePredictionItrs(bBooster, dm)
           Rabit.shutdown()
-          produceResultIterator(rowItr1, originalPredictionItr, predLeafItr, predContribItr)
+          Iterator(originalPredictionItr, predLeafItr, predContribItr)
         } finally {
           dm.delete()
         }
       } else {
-        Iterator[Row]()
+        Iterator()
       }
     }
+    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
+      case (inputIterator, predictionItr) =>
+        if (inputIterator.hasNext) {
+          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
+            predictionItr.next())
+        } else {
+          Iterator()
+        }
+    }
     bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
+    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }

   private def produceResultIterator(
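Two properties of the diffs above are worth calling out. zipPartitions requires both sides to share the same partitioning, which holds here because inputRDD and predictionRDD both derive from the same dataset.asInstanceOf[Dataset[Row]].rdd with no intervening shuffle. And because each prediction partition is packed as one ordered element, the repeated predictionItr.next() calls must match the packing order exactly: rawPredictionItr, probabilityItr, predLeafItr, predContribItr for the classifier, and originalPredictionItr, predLeafItr, predContribItr for the regressor.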