Rework transform (#7440)
Extract the common part of the transform code from XGBoostClassifier and XGBoostRegressor into PreXGBoost.
parent 2adf222fb2
commit 7cfb310eb4
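The shape of the change, in brief: both model classes previously carried a near-identical private transformInternal; this commit moves that body into a shared PreXGBoost.transformDataFrame and reduces the per-model difference to a prediction closure plus an output schema. A minimal sketch of the pattern (simplified, with hypothetical stand-in types; not the project's code):

object RefactorSketch {
  // Hypothetical stand-ins for DMatrix and Spark's Row.
  type DM = Array[Array[Float]]
  type Row = Seq[Any]

  // Each model contributes only its prediction step: a DMatrix plus the
  // original rows in, result rows out.
  type PredictFunc = (DM, Iterator[Row]) => Iterator[Row]

  // The shared transform: batch the input, build the matrix, delegate.
  def sharedTransform(rows: Iterator[Row], batchSize: Int,
      toMatrix: Seq[Row] => DM, predict: PredictFunc): Iterator[Row] =
    rows.grouped(batchSize).flatMap { batch =>
      predict(toMatrix(batch), batch.iterator)
    }
}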
PreXGBoost.scala
@@ -18,23 +18,30 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.nio.file.Files

-import scala.collection.{AbstractIterator, mutable}
+import scala.collection.JavaConverters._
+import scala.collection.{AbstractIterator, Iterator, mutable}

+import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
 import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
+import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel._originalPredictionCol
 import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon

 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.commons.logging.LogFactory

 import org.apache.spark.TaskContext
-import org.apache.spark.ml.Estimator
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel

 /**
- * PreXGBoost converts Dataset[_] to RDD[[Watches]]
+ * PreXGBoost serves preparing data before training and transform
  */
 object PreXGBoost {
@@ -117,6 +124,131 @@ object PreXGBoost
   }

+  /**
+   * Transform Dataset
+   *
+   * @param model supporting [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
+   * @param dataset the input Dataset to transform
+   * @return the transformed DataFrame
+   */
+  def transformDataFrame(model: Model[_], dataset: Dataset[_]): DataFrame = {
+
+    /** get the necessary parameters */
+    val (booster, inferBatchSize, featuresCol, useExternalMemory, missing, allowNonZeroForMissing,
+        predictFunc, schema) =
+      model match {
+        case m: XGBoostClassificationModel =>
+          // predict and turn to Row
+          val predictFunc =
+            (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
+              val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
+                m.producePredictionItrs(broadcastBooster, dm)
+              m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr,
+                predLeafItr, predContribItr)
+            }
+
+          // prepare the final Schema
+          var schema = StructType(dataset.schema.fields ++
+            Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false)) ++
+            Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false)))
+
+          if (m.isDefined(m.leafPredictionCol)) {
+            schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+          if (m.isDefined(m.contribPredictionCol)) {
+            schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+
+          (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing,
+            m.getAllowNonZeroForMissingValue, predictFunc, schema)
+
+        case m: XGBoostRegressionModel =>
+          // predict and turn to Row
+          val predictFunc =
+            (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
+              val Array(rawPredictionItr, predLeafItr, predContribItr) =
+                m.producePredictionItrs(broadcastBooster, dm)
+              m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr, predContribItr)
+            }
+
+          // prepare the final Schema
+          var schema = StructType(dataset.schema.fields ++
+            Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false)))
+
+          if (m.isDefined(m.leafPredictionCol)) {
+            schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+          if (m.isDefined(m.contribPredictionCol)) {
+            schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+
+          (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing,
+            m.getAllowNonZeroForMissingValue, predictFunc, schema)
+      }
+
+    val bBooster = dataset.sparkSession.sparkContext.broadcast(booster)
+    val appName = dataset.sparkSession.sparkContext.appName
+
+    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+      new AbstractIterator[Row] {
+        private var batchCnt = 0
+
+        private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
+          if (batchCnt == 0) {
+            val rabitEnv = Array(
+              "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+            Rabit.init(rabitEnv.asJava)
+          }
+
+          val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
+
+          import DataUtils._
+          val cacheInfo = {
+            if (useExternalMemory) {
+              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
+                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
+            } else {
+              null
+            }
+          }
+
+          val dm = new DMatrix(
+            processMissingValues(features.map(_.asXGB), missing, allowNonZeroForMissing),
+            cacheInfo)
+
+          try {
+            predictFunc(bBooster, dm, batchRow.iterator)
+          } finally {
+            batchCnt += 1
+            dm.delete()
+          }
+        }
+
+        override def hasNext: Boolean = batchIterImpl.hasNext
+
+        override def next(): Row = {
+          val ret = batchIterImpl.next()
+          if (!batchIterImpl.hasNext) {
+            Rabit.shutdown()
+          }
+          ret
+        }
+      }
+    }
+
+    bBooster.unpersist(blocking = false)
+    dataset.sparkSession.createDataFrame(resultRDD, schema)
+  }
+

  /**
   * Converting the RDD[XGBLabeledPoint] to the function to build RDD[Watches]
   *
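Note the iterator discipline in transformDataFrame above: Rabit.init runs once before the first batch, and Rabit.shutdown fires from next() as soon as the underlying iterator is exhausted. A self-contained sketch of that lifecycle, with hypothetical onFirstBatch/onExhausted hooks standing in for the Rabit calls:

class BatchedIterator[T, R](rows: Iterator[T], batchSize: Int,
    runBatch: Seq[T] => Iterator[R],
    onFirstBatch: () => Unit, onExhausted: () => Unit) extends Iterator[R] {

  private var batchCnt = 0
  private val impl = rows.grouped(batchSize).flatMap { batch =>
    if (batchCnt == 0) onFirstBatch()  // e.g. Rabit.init on the first batch
    batchCnt += 1
    runBatch(batch)
  }

  override def hasNext: Boolean = impl.hasNext
  override def next(): R = {
    val ret = impl.next()
    if (!impl.hasNext) onExhausted()   // e.g. Rabit.shutdown after the last row
    ret
  }
}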
XGBoostClassifier.scala
@@ -16,25 +16,19 @@

 package ml.dmlc.xgboost4j.scala.spark

-import ml.dmlc.xgboost4j.java.Rabit
 import ml.dmlc.xgboost4j.scala.spark.params._
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.hadoop.fs.Path
-import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.classification._
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util._
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
 import org.json4s.DefaultFormats

-import scala.collection.JavaConverters._
-import scala.collection.{AbstractIterator, Iterator, mutable}
+import scala.collection.{Iterator, mutable}

 class XGBoostClassifier (
     override val uid: String,
@@ -277,76 +271,7 @@ class XGBoostClassificationModel private[ml](
     throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
   }

-  // Generate raw prediction and probability prediction.
-  private def transformInternal(dataset: Dataset[_]): DataFrame = {
-    val schema = StructType(dataset.schema.fields ++
-      Seq(StructField(name = _rawPredictionCol, dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false)) ++
-      Seq(StructField(name = _probabilityCol, dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false)))
-
-    val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
-    val appName = dataset.sparkSession.sparkContext.appName
-
-    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      new AbstractIterator[Row] {
-        private var batchCnt = 0
-
-        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
-          if (batchCnt == 0) {
-            val rabitEnv = Array(
-              "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-            Rabit.init(rabitEnv.asJava)
-          }
-
-          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
-
-          import DataUtils._
-          val cacheInfo = {
-            if ($(useExternalMemory)) {
-              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
-                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
-            } else {
-              null
-            }
-          }
-
-          val dm = new DMatrix(
-            processMissingValues(
-              features.map(_.asXGB),
-              $(missing),
-              $(allowNonZeroForMissing)
-            ),
-            cacheInfo)
-          try {
-            val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
-              producePredictionItrs(bBooster, dm)
-            produceResultIterator(batchRow.iterator,
-              rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
-          } finally {
-            batchCnt += 1
-            dm.delete()
-          }
-        }
-
-        override def hasNext: Boolean = batchIterImpl.hasNext
-
-        override def next(): Row = {
-          val ret = batchIterImpl.next()
-          if (!batchIterImpl.hasNext) {
-            Rabit.shutdown()
-          }
-          ret
-        }
-      }
-    }
-
-    bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
-  }
-
-  private def produceResultIterator(
+  private[spark] def produceResultIterator(
       originalRowItr: Iterator[Row],
       rawPredictionItr: Iterator[Row],
       probabilityItr: Iterator[Row],
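The removed body above also shows what produceResultIterator (now private[spark]) is fed: the original row iterator plus one iterator per prediction kind. Presumably it zips them back together so each output row is the input row extended with prediction columns; a minimal sketch of that zip, with simplified types (not the actual implementation):

def zipResults(original: Iterator[Seq[Any]], raw: Iterator[Seq[Any]],
    prob: Iterator[Seq[Any]]): Iterator[Seq[Any]] =
  original.zip(raw.zip(prob)).map { case (row, (r, p)) => row ++ r ++ p }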
@@ -381,20 +306,7 @@ class XGBoostClassificationModel private[ml](
     }
   }

-  private def generateResultSchema(fixedSchema: StructType): StructType = {
-    var resultSchema = fixedSchema
-    if (isDefined(leafPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    if (isDefined(contribPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    resultSchema
-  }
-
-  private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
+  private[spark] def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
       Array[Iterator[Row]] = {
     val rawPredictionItr = {
       broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)).
@@ -431,7 +343,7 @@ class XGBoostClassificationModel private[ml](

     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
-    var outputData = transformInternal(dataset)
+    var outputData = PreXGBoost.transformDataFrame(this, dataset)
     var numColsOutput = 0

     val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
@@ -492,8 +404,8 @@ class XGBoostClassificationModel private[ml](

 object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {

-  private val _rawPredictionCol = "_rawPrediction"
-  private val _probabilityCol = "_probability"
+  private[spark] val _rawPredictionCol = "_rawPrediction"
+  private[spark] val _probabilityCol = "_probability"

   override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader
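The widening from private to private[spark] in the hunk above is what lets PreXGBoost, which lives in the same ml.dmlc.xgboost4j.scala.spark package but a different file, reach these members. A quick illustration of the difference (demo objects, not project code):

package ml.dmlc.xgboost4j.scala.spark

object VisibilityDemo {
  private val hidden = 1           // visible only inside VisibilityDemo
  private[spark] val shared = 2    // visible anywhere under the spark package
}

object SamePackageCaller {
  // val no = VisibilityDemo.hidden    // would not compile
  val yes: Int = VisibilityDemo.shared // compiles: same enclosing package
}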
XGBoostRegressor.scala
@@ -16,25 +16,19 @@

 package ml.dmlc.xgboost4j.scala.spark

-import scala.collection.{AbstractIterator, Iterator, mutable}
-import scala.collection.JavaConverters._
+import scala.collection.{Iterator, mutable}

-import ml.dmlc.xgboost4j.java.Rabit
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
 import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import org.apache.hadoop.fs.Path

-import org.apache.spark.TaskContext
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.util._
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
 import org.json4s.DefaultFormats

 import org.apache.spark.broadcast.Broadcast
@@ -257,71 +251,7 @@ class XGBoostRegressionModel private[ml] (
     _booster.predict(data = dm)(0)(0)
   }

-  private def transformInternal(dataset: Dataset[_]): DataFrame = {
-    val schema = StructType(dataset.schema.fields ++
-      Seq(StructField(name = _originalPredictionCol, dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false)))
-
-    val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
-    val appName = dataset.sparkSession.sparkContext.appName
-
-    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      new AbstractIterator[Row] {
-        private var batchCnt = 0
-
-        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
-          if (batchCnt == 0) {
-            val rabitEnv = Array(
-              "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-            Rabit.init(rabitEnv.asJava)
-          }
-
-          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
-
-          import DataUtils._
-          val cacheInfo = {
-            if ($(useExternalMemory)) {
-              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
-                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
-            } else {
-              null
-            }
-          }
-
-          val dm = new DMatrix(
-            processMissingValues(
-              features.map(_.asXGB),
-              $(missing),
-              $(allowNonZeroForMissing)
-            ),
-            cacheInfo)
-          try {
-            val Array(rawPredictionItr, predLeafItr, predContribItr) =
-              producePredictionItrs(bBooster, dm)
-            produceResultIterator(batchRow.iterator, rawPredictionItr, predLeafItr, predContribItr)
-          } finally {
-            batchCnt += 1
-            dm.delete()
-          }
-        }
-
-        override def hasNext: Boolean = batchIterImpl.hasNext
-
-        override def next(): Row = {
-          val ret = batchIterImpl.next()
-          if (!batchIterImpl.hasNext) {
-            Rabit.shutdown()
-          }
-          ret
-        }
-      }
-    }
-    bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
-  }
-
-  private def produceResultIterator(
+  private[spark] def produceResultIterator(
       originalRowItr: Iterator[Row],
       predictionItr: Iterator[Row],
       predLeafItr: Iterator[Row],
@@ -353,20 +283,7 @@ class XGBoostRegressionModel private[ml] (
     }
   }

-  private def generateResultSchema(fixedSchema: StructType): StructType = {
-    var resultSchema = fixedSchema
-    if (isDefined(leafPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    if (isDefined(contribPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    resultSchema
-  }
-
-  private def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
+  private[spark] def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
       Array[Iterator[Row]] = {
     val originalPredictionItr = {
       booster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
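The schema logic that generateResultSchema handled now lives in PreXGBoost: each optional prediction column is appended to the input schema as a non-null array of floats. A small sketch of that append, using the Spark SQL types the diff already imports:

import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}

def withPredictionCol(input: StructType, name: String): StructType =
  input.add(StructField(name, ArrayType(FloatType, containsNull = false),
    nullable = false))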
@@ -394,7 +311,7 @@ class XGBoostRegressionModel private[ml] (
     transformSchema(dataset.schema, logging = true)
     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
-    var outputData = transformInternal(dataset)
+    var outputData = PreXGBoost.transformDataFrame(this, dataset)
     var numColsOutput = 0

     val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
@@ -425,7 +342,7 @@ class XGBoostRegressionModel private[ml] (

 object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {

-  private val _originalPredictionCol = "_originalPrediction"
+  private[spark] val _originalPredictionCol = "_originalPrediction"

   override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader