[jvm-packages] fix the scalability issue of prediction (#4033)
parent 15fe2f1e7c
commit f368d0de2b
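Prediction in XGBoostClassificationModel and XGBoostRegressionModel no longer duplicates the per-partition row iterator in order to stitch predictions onto rows inside a single mapPartitions pass; instead, the prediction iterators are produced in their own RDD and re-attached to the untouched input partitions with zipPartitions. The build also drops the jacoco-maven-plugin block from the pom.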
@@ -335,25 +335,6 @@
           </execution>
         </executions>
       </plugin>
-      <plugin>
-        <groupId>org.jacoco</groupId>
-        <artifactId>jacoco-maven-plugin</artifactId>
-        <version>0.7.9</version>
-        <executions>
-          <execution>
-            <goals>
-              <goal>prepare-agent</goal>
-            </goals>
-          </execution>
-          <execution>
-            <id>report</id>
-            <phase>test</phase>
-            <goals>
-              <goal>report</goal>
-            </goals>
-          </execution>
-        </executions>
-      </plugin>
     </plugins>
   </build>
   <dependencies>
@@ -285,12 +285,12 @@ class XGBoostClassificationModel private[ml](
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
 
-    val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
+    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
       if (rowIterator.hasNext) {
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
-        val (rowItr1, rowItr2) = rowIterator.duplicate
-        val featuresIterator = rowItr2.map(row => row.getAs[Vector](
+        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
           $(featuresCol))).toList.iterator
         import DataUtils._
         val cacheInfo = {
@@ -307,19 +307,27 @@ class XGBoostClassificationModel private[ml](
           val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
             producePredictionItrs(bBooster, dm)
           Rabit.shutdown()
-          produceResultIterator(rowItr1, rawPredictionItr, probabilityItr, predLeafItr,
+          Iterator(rawPredictionItr, probabilityItr, predLeafItr,
             predContribItr)
         } finally {
           dm.delete()
         }
       } else {
-        Iterator[Row]()
+        Iterator()
       }
     }
 
+    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
+      case (inputIterator, predictionItr) =>
+        if (inputIterator.hasNext) {
+          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
+            predictionItr.next(), predictionItr.next())
+        } else {
+          Iterator()
+        }
+    }
+
     bBooster.unpersist(blocking = false)
-
-    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
+    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }
 
   private def produceResultIterator(
@@ -257,13 +257,12 @@ class XGBoostRegressionModel private[ml] (
 
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
-
-    val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
+    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
       if (rowIterator.hasNext) {
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
-        val (rowItr1, rowItr2) = rowIterator.duplicate
-        val featuresIterator = rowItr2.map(row => row.getAs[Vector](
+        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
           $(featuresCol))).toList.iterator
         import DataUtils._
         val cacheInfo = {
@@ -273,7 +272,6 @@ class XGBoostRegressionModel private[ml] (
             null
           }
         }
-
         val dm = new DMatrix(
           XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
           cacheInfo)
@@ -281,16 +279,25 @@ class XGBoostRegressionModel private[ml] (
           val Array(originalPredictionItr, predLeafItr, predContribItr) =
             producePredictionItrs(bBooster, dm)
           Rabit.shutdown()
-          produceResultIterator(rowItr1, originalPredictionItr, predLeafItr, predContribItr)
+          Iterator(originalPredictionItr, predLeafItr, predContribItr)
         } finally {
           dm.delete()
         }
       } else {
-        Iterator[Row]()
+        Iterator()
       }
     }
+    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
+      case (inputIterator, predictionItr) =>
+        if (inputIterator.hasNext) {
+          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
+            predictionItr.next())
+        } else {
+          Iterator()
+        }
+    }
     bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
+    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }
 
   private def produceResultIterator(
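The regression model gets the same treatment, minus the probability column. Below is a minimal, self-contained Spark sketch of the resulting two-pass pattern, with toy arithmetic standing in for DMatrix construction and Booster prediction; ZipPartitionsSketch, the local master, and the doubled values are illustrative and not part of xgboost4j-spark:

import org.apache.spark.sql.SparkSession

object ZipPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("zip-partitions-sketch")
      .getOrCreate()
    val sc = spark.sparkContext

    // Stand-in for dataset.asInstanceOf[Dataset[Row]].rdd.
    val inputRDD = sc.parallelize(1 to 10, numSlices = 2)

    // One batched "prediction" iterator per non-empty partition, mirroring how
    // the model emits Iterator(rawPredictionItr, probabilityItr, ...).
    val predictionRDD = inputRDD.mapPartitions { rows =>
      if (rows.hasNext) {
        val preds = rows.map(_ * 2.0).toArray // stand-in for Booster.predict on a DMatrix
        Iterator(preds.iterator)
      } else {
        Iterator()
      }
    }

    // Re-attach each prediction to its originating row without duplicating
    // (and therefore buffering) the row iterator inside mapPartitions.
    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
      case (inputIterator, predictionItr) =>
        if (inputIterator.hasNext) {
          inputIterator.zip(predictionItr.next())
        } else {
          Iterator()
        }
    }

    resultRDD.collect().foreach(println) // (1,2.0), (2,4.0), ...
    spark.stop()
  }
}

Because zipPartitions combines co-partitioned parents inside a single task, the per-partition prediction iterators are never shuffled or serialized, which is what makes returning Iterator(originalPredictionItr, ...) from mapPartitions workable here.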