Rework transform (#7440)

Extract the common part of the transform code from XGBoostClassifier
and XGBoostRegressor into PreXGBoost.
Bobby Wang, 2021-11-18 15:48:57 +08:00, committed by GitHub
parent 2adf222fb2
commit 7cfb310eb4
3 changed files with 147 additions and 186 deletions
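In outline, the change replaces two near-identical transformInternal methods with one shared driver. A minimal sketch of the pattern, with simplified, hypothetical types (the real code dispatches on XGBoostClassificationModel and XGBoostRegressionModel and works over Spark Rows):

    // Each model contributes only what differs: how to score a batch and which
    // output columns it appends. The shared driver owns batching and assembly.
    trait BatchScorer[Feat, Pred] {
      def scoreBatch(batch: Seq[Feat]): Iterator[Pred]
    }

    // Shared transform loop, analogous in shape to PreXGBoost.transformDataFrame:
    // group the input into batches, score each batch, and pair predictions
    // back up with the rows they came from.
    def sharedTransform[Feat, Pred](scorer: BatchScorer[Feat, Pred],
        rows: Iterator[Feat], batchSize: Int): Iterator[(Feat, Pred)] =
      rows.grouped(batchSize).flatMap { batch =>
        batch.iterator.zip(scorer.scoreBatch(batch))
      }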

PreXGBoost.scala

@@ -18,23 +18,30 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.nio.file.Files

-import scala.collection.{AbstractIterator, mutable}
+import scala.collection.JavaConverters._
+import scala.collection.{AbstractIterator, Iterator, mutable}

+import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
 import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
+import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel._originalPredictionCol
 import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.TaskContext
-import org.apache.spark.ml.Estimator
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel

 /**
- * PreXGBoost converts Dataset[_] to RDD[[Watches]]
+ * PreXGBoost serves to prepare data before training and during transform
  */
 object PreXGBoost {
@@ -117,6 +124,131 @@ object PreXGBoost {
   }

+  /**
+   * Transform the input Dataset with the given model
+   *
+   * @param model   either an [[XGBoostClassificationModel]] or an [[XGBoostRegressionModel]]
+   * @param dataset the input Dataset to transform
+   * @return the transformed DataFrame
+   */
+  def transformDataFrame(model: Model[_], dataset: Dataset[_]): DataFrame = {
+
+    // Get the necessary parameters from the concrete model
+    val (booster, inferBatchSize, featuresCol, useExternalMemory, missing, allowNonZeroForMissing,
+      predictFunc, schema) =
+      model match {
+        case m: XGBoostClassificationModel =>
+          // predict and turn to Row
+          val predictFunc =
+            (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
+              val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
+                m.producePredictionItrs(broadcastBooster, dm)
+              m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr,
+                predLeafItr, predContribItr)
+            }
+
+          // prepare the final schema
+          var schema = StructType(dataset.schema.fields ++
+            Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false)) ++
+            Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false)))
+
+          if (m.isDefined(m.leafPredictionCol)) {
+            schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+          if (m.isDefined(m.contribPredictionCol)) {
+            schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+
+          (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing,
+            m.getAllowNonZeroForMissingValue, predictFunc, schema)
+
+        case m: XGBoostRegressionModel =>
+          // predict and turn to Row
+          val predictFunc =
+            (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
+              val Array(rawPredictionItr, predLeafItr, predContribItr) =
+                m.producePredictionItrs(broadcastBooster, dm)
+              m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr,
+                predContribItr)
+            }
+
+          // prepare the final schema
+          var schema = StructType(dataset.schema.fields ++
+            Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false)))
+
+          if (m.isDefined(m.leafPredictionCol)) {
+            schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+          if (m.isDefined(m.contribPredictionCol)) {
+            schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+
+          (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing,
+            m.getAllowNonZeroForMissingValue, predictFunc, schema)
+      }
+
+    val bBooster = dataset.sparkSession.sparkContext.broadcast(booster)
+    val appName = dataset.sparkSession.sparkContext.appName
+
+    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+      new AbstractIterator[Row] {
+        private var batchCnt = 0
+
+        private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
+          if (batchCnt == 0) {
+            val rabitEnv = Array(
+              "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+            Rabit.init(rabitEnv.asJava)
+          }
+
+          val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
+
+          import DataUtils._
+          val cacheInfo = {
+            if (useExternalMemory) {
+              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
+                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
+            } else {
+              null
+            }
+          }
+
+          val dm = new DMatrix(
+            processMissingValues(features.map(_.asXGB), missing, allowNonZeroForMissing),
+            cacheInfo)
+
+          try {
+            predictFunc(bBooster, dm, batchRow.iterator)
+          } finally {
+            batchCnt += 1
+            dm.delete()
+          }
+        }
+
+        override def hasNext: Boolean = batchIterImpl.hasNext
+
+        override def next(): Row = {
+          val ret = batchIterImpl.next()
+          if (!batchIterImpl.hasNext) {
+            Rabit.shutdown()
+          }
+          ret
+        }
+      }
+    }
+
+    bBooster.unpersist(blocking = false)
+    dataset.sparkSession.createDataFrame(resultRDD, schema)
+  }

   /**
    * Converting the RDD[XGBLabeledPoint] to the function to build RDD[Watches]
    *
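The partition-level iterator in transformDataFrame follows a setup-once, teardown-on-exhaustion pattern: Rabit.init runs before the first batch is scored, each batch's DMatrix is deleted as soon as its predictions are produced, and Rabit.shutdown fires once the last row has been consumed. A generic sketch of that lifecycle pattern, with no XGBoost dependency (the helper name is hypothetical):

    import scala.collection.AbstractIterator

    // Lazily group a row iterator into batches, run setup before the first
    // batch, and run teardown after the final element is consumed. As in the
    // code above, teardown only fires if the caller exhausts the iterator.
    def batchedWithLifecycle[A, B](rows: Iterator[A], batchSize: Int)
        (setup: () => Unit)(process: Seq[A] => Iterator[B])(teardown: () => Unit): Iterator[B] =
      new AbstractIterator[B] {
        private var batchCnt = 0
        private val impl = rows.grouped(batchSize).flatMap { batch =>
          if (batchCnt == 0) setup()     // mirrors Rabit.init on the first batch
          batchCnt += 1
          process(batch)
        }
        override def hasNext: Boolean = impl.hasNext
        override def next(): B = {
          val ret = impl.next()
          if (!impl.hasNext) teardown()  // mirrors Rabit.shutdown after the last row
          ret
        }
      }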

XGBoostClassifier.scala

@@ -16,25 +16,19 @@

 package ml.dmlc.xgboost4j.scala.spark

-import ml.dmlc.xgboost4j.java.Rabit
 import ml.dmlc.xgboost4j.scala.spark.params._
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.hadoop.fs.Path
-import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.classification._
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util._
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
 import org.json4s.DefaultFormats

-import scala.collection.JavaConverters._
-import scala.collection.{AbstractIterator, Iterator, mutable}
+import scala.collection.{Iterator, mutable}

 class XGBoostClassifier (
     override val uid: String,
@@ -277,76 +271,7 @@ class XGBoostClassificationModel private[ml](
     throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
   }

-  // Generate raw prediction and probability prediction.
-  private def transformInternal(dataset: Dataset[_]): DataFrame = {
-    val schema = StructType(dataset.schema.fields ++
-      Seq(StructField(name = _rawPredictionCol, dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false)) ++
-      Seq(StructField(name = _probabilityCol, dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false)))
-
-    val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
-    val appName = dataset.sparkSession.sparkContext.appName
-
-    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      new AbstractIterator[Row] {
-        private var batchCnt = 0
-
-        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
-          if (batchCnt == 0) {
-            val rabitEnv = Array(
-              "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-            Rabit.init(rabitEnv.asJava)
-          }
-
-          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
-
-          import DataUtils._
-          val cacheInfo = {
-            if ($(useExternalMemory)) {
-              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
-                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
-            } else {
-              null
-            }
-          }
-
-          val dm = new DMatrix(
-            processMissingValues(
-              features.map(_.asXGB),
-              $(missing),
-              $(allowNonZeroForMissing)
-            ),
-            cacheInfo)
-
-          try {
-            val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
-              producePredictionItrs(bBooster, dm)
-            produceResultIterator(batchRow.iterator,
-              rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
-          } finally {
-            batchCnt += 1
-            dm.delete()
-          }
-        }
-
-        override def hasNext: Boolean = batchIterImpl.hasNext
-
-        override def next(): Row = {
-          val ret = batchIterImpl.next()
-          if (!batchIterImpl.hasNext) {
-            Rabit.shutdown()
-          }
-          ret
-        }
-      }
-    }
-
-    bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
-  }
-
-  private def produceResultIterator(
+  private[spark] def produceResultIterator(
       originalRowItr: Iterator[Row],
       rawPredictionItr: Iterator[Row],
       probabilityItr: Iterator[Row],
@@ -381,20 +306,7 @@ class XGBoostClassificationModel private[ml](
     }
   }

-  private def generateResultSchema(fixedSchema: StructType): StructType = {
-    var resultSchema = fixedSchema
-    if (isDefined(leafPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    if (isDefined(contribPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    resultSchema
-  }
-
-  private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
+  private[spark] def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
       Array[Iterator[Row]] = {
     val rawPredictionItr = {
       broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)).
@@ -431,7 +343,7 @@ class XGBoostClassificationModel private[ml](

     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
-    var outputData = transformInternal(dataset)
+    var outputData = PreXGBoost.transformDataFrame(this, dataset)
     var numColsOutput = 0

     val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
@@ -492,8 +404,8 @@ class XGBoostClassificationModel private[ml](

 object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {

-  private val _rawPredictionCol = "_rawPrediction"
-  private val _probabilityCol = "_probability"
+  private[spark] val _rawPredictionCol = "_rawPrediction"
+  private[spark] val _probabilityCol = "_probability"

   override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader
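The bodies of produceResultIterator and producePredictionItrs are unchanged by this commit and fall outside the hunks above; only their visibility widens from private to private[spark] so PreXGBoost can call them. As a toy sketch only (not the actual implementation), the contract implied by the signatures is a lockstep zip that appends prediction columns to each original row:

    import org.apache.spark.sql.Row

    // Hypothetical illustration: walk the original rows and two of the
    // prediction iterators in lockstep, appending the prediction fields.
    def appendPredictions(originalRowItr: Iterator[Row], rawPredictionItr: Iterator[Row],
        probabilityItr: Iterator[Row]): Iterator[Row] =
      originalRowItr.zip(rawPredictionItr.zip(probabilityItr)).map {
        case (row, (rawPred, prob)) => Row.fromSeq(row.toSeq ++ rawPred.toSeq ++ prob.toSeq)
      }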

XGBoostRegressor.scala

@@ -16,25 +16,19 @@

 package ml.dmlc.xgboost4j.scala.spark

-import scala.collection.{AbstractIterator, Iterator, mutable}
-import scala.collection.JavaConverters._
+import scala.collection.{Iterator, mutable}

-import ml.dmlc.xgboost4j.java.Rabit
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
 import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import org.apache.hadoop.fs.Path
-import org.apache.spark.TaskContext
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.util._
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
 import org.json4s.DefaultFormats
 import org.apache.spark.broadcast.Broadcast
@@ -257,71 +251,7 @@ class XGBoostRegressionModel private[ml] (
     _booster.predict(data = dm)(0)(0)
   }

-  private def transformInternal(dataset: Dataset[_]): DataFrame = {
-    val schema = StructType(dataset.schema.fields ++
-      Seq(StructField(name = _originalPredictionCol, dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false)))
-
-    val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
-    val appName = dataset.sparkSession.sparkContext.appName
-
-    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      new AbstractIterator[Row] {
-        private var batchCnt = 0
-
-        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
-          if (batchCnt == 0) {
-            val rabitEnv = Array(
-              "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-            Rabit.init(rabitEnv.asJava)
-          }
-
-          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
-
-          import DataUtils._
-          val cacheInfo = {
-            if ($(useExternalMemory)) {
-              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
-                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
-            } else {
-              null
-            }
-          }
-
-          val dm = new DMatrix(
-            processMissingValues(
-              features.map(_.asXGB),
-              $(missing),
-              $(allowNonZeroForMissing)
-            ),
-            cacheInfo)
-
-          try {
-            val Array(rawPredictionItr, predLeafItr, predContribItr) =
-              producePredictionItrs(bBooster, dm)
-            produceResultIterator(batchRow.iterator, rawPredictionItr, predLeafItr, predContribItr)
-          } finally {
-            batchCnt += 1
-            dm.delete()
-          }
-        }
-
-        override def hasNext: Boolean = batchIterImpl.hasNext
-
-        override def next(): Row = {
-          val ret = batchIterImpl.next()
-          if (!batchIterImpl.hasNext) {
-            Rabit.shutdown()
-          }
-          ret
-        }
-      }
-    }
-
-    bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
-  }
-
-  private def produceResultIterator(
+  private[spark] def produceResultIterator(
       originalRowItr: Iterator[Row],
       predictionItr: Iterator[Row],
       predLeafItr: Iterator[Row],
@@ -353,20 +283,7 @@ class XGBoostRegressionModel private[ml] (
     }
   }

-  private def generateResultSchema(fixedSchema: StructType): StructType = {
-    var resultSchema = fixedSchema
-    if (isDefined(leafPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    if (isDefined(contribPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    resultSchema
-  }
-
-  private def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
+  private[spark] def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
       Array[Iterator[Row]] = {
     val originalPredictionItr = {
       booster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
@@ -394,7 +311,7 @@ class XGBoostRegressionModel private[ml] (
     transformSchema(dataset.schema, logging = true)

     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
-    var outputData = transformInternal(dataset)
+    var outputData = PreXGBoost.transformDataFrame(this, dataset)
     var numColsOutput = 0

     val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
@@ -425,7 +342,7 @@ class XGBoostRegressionModel private[ml] (

 object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {

-  private val _originalPredictionCol = "_originalPrediction"
+  private[spark] val _originalPredictionCol = "_originalPrediction"

   override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader
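Caller-facing behavior is unchanged by the refactor: transform() on either model still returns the same DataFrame, it simply routes through PreXGBoost.transformDataFrame(this, dataset) internally. A hedged usage sketch, assuming a SparkSession named spark, a fitted XGBoostClassificationModel named model, and an illustrative input path:

    import org.apache.spark.sql.DataFrame

    val testDf: DataFrame = spark.read.parquet("/tmp/test-data")  // hypothetical input
    val scored: DataFrame = model.transform(testDf)               // delegates to PreXGBoost.transformDataFrame
    scored.select("prediction", "probability").show(5)            // assumes default output column names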