diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala
index e457ef405..be8893401 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala
@@ -18,23 +18,30 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import java.nio.file.Files
 
-import scala.collection.{AbstractIterator, mutable}
+import scala.collection.JavaConverters._
+import scala.collection.{AbstractIterator, Iterator, mutable}
 
+import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
 import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
+import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel._originalPredictionCol
 import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.TaskContext
-import org.apache.spark.ml.Estimator
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel
 
 /**
- * PreXGBoost converts Dataset[_] to RDD[[Watches]]
+ * PreXGBoost prepares data before training and handles Dataset transformation for prediction
  */
 object PreXGBoost {
 
@@ -117,6 +124,131 @@ object PreXGBoost {
 
   }
 
+  /**
+   * Transform the input Dataset with the given model
+   *
+   * @param model an [[XGBoostClassificationModel]] or an [[XGBoostRegressionModel]]
+   * @param dataset the input Dataset to transform
+   * @return the transformed DataFrame
+   */
+  def transformDataFrame(model: Model[_], dataset: Dataset[_]): DataFrame = {
+
+    // get the prediction parameters from the concrete model type
+    val (booster, inferBatchSize, featuresCol, useExternalMemory, missing, allowNonZeroForMissing,
+      predictFunc, schema) =
+      model match {
+        case m: XGBoostClassificationModel =>
+
+          // run prediction and turn the results into Rows
+          val predictFunc =
+            (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
+              val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
+                m.producePredictionItrs(broadcastBooster, dm)
+              m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr,
+                predLeafItr, predContribItr)
+            }
+
+          // prepare the final schema
+          var schema = StructType(dataset.schema.fields ++
+            Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false)) ++
+            Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false)))
+
+          if (m.isDefined(m.leafPredictionCol)) {
+            schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+          if (m.isDefined(m.contribPredictionCol)) {
+            schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+
+          (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing,
+            m.getAllowNonZeroForMissingValue, predictFunc, schema)
+
+        case m: XGBoostRegressionModel =>
+          // run prediction and turn the results into Rows
+          val predictFunc =
+            (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
+              val Array(rawPredictionItr, predLeafItr, predContribItr) =
+                m.producePredictionItrs(broadcastBooster, dm)
+              m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr, predContribItr)
+            }
+
+          // prepare the final schema
+          var schema = StructType(dataset.schema.fields ++
+            Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false)))
+
+          if (m.isDefined(m.leafPredictionCol)) {
+            schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+          if (m.isDefined(m.contribPredictionCol)) {
+            schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
+              ArrayType(FloatType, containsNull = false), nullable = false))
+          }
+
+          (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing,
+            m.getAllowNonZeroForMissingValue, predictFunc, schema)
+      }
+
+    val bBooster = dataset.sparkSession.sparkContext.broadcast(booster)
+    val appName = dataset.sparkSession.sparkContext.appName
+
+    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+      new AbstractIterator[Row] {
+        private var batchCnt = 0
+
+        private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
+          if (batchCnt == 0) {
+            val rabitEnv = Array(
+              "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+            Rabit.init(rabitEnv.asJava)
+          }
+
+          val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
+
+          import DataUtils._
+          val cacheInfo = {
+            if (useExternalMemory) {
+              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
+                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
+            } else {
+              null
+            }
+          }
+
+          val dm = new DMatrix(
+            processMissingValues(features.map(_.asXGB), missing, allowNonZeroForMissing),
+            cacheInfo)
+
+          try {
+            predictFunc(bBooster, dm, batchRow.iterator)
+          } finally {
+            batchCnt += 1
+            dm.delete()
+          }
+        }
+
+        override def hasNext: Boolean = batchIterImpl.hasNext
+
+        override def next(): Row = {
+          val ret = batchIterImpl.next()
+          if (!batchIterImpl.hasNext) {
+            Rabit.shutdown()
+          }
+          ret
+        }
+      }
+    }
+
+    bBooster.unpersist(blocking = false)
+    dataset.sparkSession.createDataFrame(resultRDD, schema)
+  }
+
+
   /**
    * Converting the RDD[XGBLabeledPoint] to the function to build RDD[Watches]
    *
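Note for reviewers: after this patch, both model types funnel their `transform` through `PreXGBoost.transformDataFrame`, so batching, missing-value handling, and the Rabit lifecycle live in a single place. Below is a minimal usage sketch of the shared entry point; the local SparkSession, toy DataFrame, and parameter map are illustrative stand-ins, not part of the patch:

```scala
import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, XGBoostClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Toy training data: a "features" vector column plus a "label" column.
val trainDf = spark.createDataFrame(Seq(
  (Vectors.dense(1.0, 2.0), 0.0),
  (Vectors.dense(3.0, 4.0), 1.0)
)).toDF("features", "label")

val model = new XGBoostClassifier(
  Map("objective" -> "binary:logistic", "num_round" -> 10, "num_workers" -> 1)
).fit(trainDf)

// Equivalent to model.transform(trainDf): rows are grouped by inferBatchSize,
// one DMatrix is built per batch, and the prediction columns declared in the
// schema assembled in transformDataFrame are appended to the original rows.
val predicted: DataFrame = PreXGBoost.transformDataFrame(model, trainDf)
```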
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index 30c701a7b..ed948b0f1 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -16,25 +16,19 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import ml.dmlc.xgboost4j.java.Rabit
 import ml.dmlc.xgboost4j.scala.spark.params._
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.classification._
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util._
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
 import org.json4s.DefaultFormats
 
-import scala.collection.JavaConverters._
-import scala.collection.{AbstractIterator, Iterator, mutable}
+import scala.collection.{Iterator, mutable}
 
 class XGBoostClassifier (
     override val uid: String,
@@ -277,76 +271,7 @@ class XGBoostClassificationModel private[ml](
     throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
   }
 
-  // Generate raw prediction and probability prediction.
-  private def transformInternal(dataset: Dataset[_]): DataFrame = {
-
-    val schema = StructType(dataset.schema.fields ++
-      Seq(StructField(name = _rawPredictionCol, dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false)) ++
-      Seq(StructField(name = _probabilityCol, dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false)))
-
-    val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
-    val appName = dataset.sparkSession.sparkContext.appName
-
-    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      new AbstractIterator[Row] {
-        private var batchCnt = 0
-
-        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
-          if (batchCnt == 0) {
-            val rabitEnv = Array(
-              "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-            Rabit.init(rabitEnv.asJava)
-          }
-
-          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
-
-          import DataUtils._
-          val cacheInfo = {
-            if ($(useExternalMemory)) {
-              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
-                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
-            } else {
-              null
-            }
-          }
-
-          val dm = new DMatrix(
-            processMissingValues(
-              features.map(_.asXGB),
-              $(missing),
-              $(allowNonZeroForMissing)
-            ),
-            cacheInfo)
-          try {
-            val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
-              producePredictionItrs(bBooster, dm)
-            produceResultIterator(batchRow.iterator,
-              rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
-          } finally {
-            batchCnt += 1
-            dm.delete()
-          }
-        }
-
-        override def hasNext: Boolean = batchIterImpl.hasNext
-
-        override def next(): Row = {
-          val ret = batchIterImpl.next()
-          if (!batchIterImpl.hasNext) {
-            Rabit.shutdown()
-          }
-          ret
-        }
-      }
-    }
-
-    bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
-  }
-
-  private def produceResultIterator(
+  private[spark] def produceResultIterator(
       originalRowItr: Iterator[Row],
       rawPredictionItr: Iterator[Row],
       probabilityItr: Iterator[Row],
@@ -381,20 +306,7 @@ class XGBoostClassificationModel private[ml](
     }
   }
 
-  private def generateResultSchema(fixedSchema: StructType): StructType = {
-    var resultSchema = fixedSchema
-    if (isDefined(leafPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    if (isDefined(contribPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    resultSchema
-  }
-
-  private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
+  private[spark] def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
       Array[Iterator[Row]] = {
     val rawPredictionItr = {
       broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)).
@@ -431,7 +343,7 @@ class XGBoostClassificationModel private[ml](
     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
-    var outputData = transformInternal(dataset)
+    var outputData = PreXGBoost.transformDataFrame(this, dataset)
     var numColsOutput = 0
 
     val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
@@ -492,8 +404,8 @@ class XGBoostClassificationModel private[ml](
 
 object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
 
-  private val _rawPredictionCol = "_rawPrediction"
-  private val _probabilityCol = "_probability"
+  private[spark] val _rawPredictionCol = "_rawPrediction"
+  private[spark] val _probabilityCol = "_probability"
 
   override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader
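Note for reviewers: the block deleted above (and its twin removed from XGBoostRegressor.scala below) survives unchanged inside `PreXGBoost.transformDataFrame`: Rabit is initialized lazily before the first batch of a partition and shut down inside the `next()` call that hands out the last row. A self-contained sketch of that iterator pattern; `withLifecycle`, `setup`, and `teardown` are hypothetical names standing in for the inline `AbstractIterator` and `Rabit.init`/`Rabit.shutdown`:

```scala
import scala.collection.AbstractIterator

// Run `setup` just before the first element is produced and `teardown` inside
// the call that produces the last one, mirroring the Rabit handling above.
// Caveat (shared with the real code): neither callback fires on an empty iterator.
def withLifecycle[T](it: Iterator[T])(setup: () => Unit, teardown: () => Unit): Iterator[T] =
  new AbstractIterator[T] {
    private var started = false
    override def hasNext: Boolean = it.hasNext
    override def next(): T = {
      if (!started) { setup(); started = true }
      val ret = it.next()
      if (!it.hasNext) teardown() // fires exactly once, while returning the final element
      ret
    }
  }

// Output order: init, 1, 2, shutdown, 3 (teardown runs inside the final next()).
withLifecycle(Iterator(1, 2, 3))(() => println("init"), () => println("shutdown"))
  .foreach(println)
```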
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
index 6810a1bb7..ff446cb08 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
@@ -16,25 +16,19 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import scala.collection.{AbstractIterator, Iterator, mutable}
-import scala.collection.JavaConverters._
+import scala.collection.{Iterator, mutable}
 
-import ml.dmlc.xgboost4j.java.Rabit
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
 import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.TaskContext
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.util._
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
 import org.json4s.DefaultFormats
 
 import org.apache.spark.broadcast.Broadcast
@@ -257,71 +251,7 @@ class XGBoostRegressionModel private[ml] (
     _booster.predict(data = dm)(0)(0)
   }
 
-  private def transformInternal(dataset: Dataset[_]): DataFrame = {
-
-    val schema = StructType(dataset.schema.fields ++
-      Seq(StructField(name = _originalPredictionCol, dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false)))
-
-    val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
-    val appName = dataset.sparkSession.sparkContext.appName
-
-    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
-      new AbstractIterator[Row] {
-        private var batchCnt = 0
-
-        private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
-          if (batchCnt == 0) {
-            val rabitEnv = Array(
-              "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-            Rabit.init(rabitEnv.asJava)
-          }
-
-          val features = batchRow.iterator.map(row => row.getAs[Vector]($(featuresCol)))
-
-          import DataUtils._
-          val cacheInfo = {
-            if ($(useExternalMemory)) {
-              s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
-                s"${TaskContext.getPartitionId()}-batch-$batchCnt"
-            } else {
-              null
-            }
-          }
-
-          val dm = new DMatrix(
-            processMissingValues(
-              features.map(_.asXGB),
-              $(missing),
-              $(allowNonZeroForMissing)
-            ),
-            cacheInfo)
-          try {
-            val Array(rawPredictionItr, predLeafItr, predContribItr) =
-              producePredictionItrs(bBooster, dm)
-            produceResultIterator(batchRow.iterator, rawPredictionItr, predLeafItr, predContribItr)
-          } finally {
-            batchCnt += 1
-            dm.delete()
-          }
-        }
-
-        override def hasNext: Boolean = batchIterImpl.hasNext
-
-        override def next(): Row = {
-          val ret = batchIterImpl.next()
-          if (!batchIterImpl.hasNext) {
-            Rabit.shutdown()
-          }
-          ret
-        }
-      }
-    }
-    bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
-  }
-
-  private def produceResultIterator(
+  private[spark] def produceResultIterator(
      originalRowItr: Iterator[Row],
      predictionItr: Iterator[Row],
      predLeafItr: Iterator[Row],
@@ -353,20 +283,7 @@ class XGBoostRegressionModel private[ml] (
     }
   }
 
-  private def generateResultSchema(fixedSchema: StructType): StructType = {
-    var resultSchema = fixedSchema
-    if (isDefined(leafPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    if (isDefined(contribPredictionCol)) {
-      resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
-        ArrayType(FloatType, containsNull = false), nullable = false))
-    }
-    resultSchema
-  }
-
-  private def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
+  private[spark] def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
       Array[Iterator[Row]] = {
     val originalPredictionItr = {
       booster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
@@ -394,7 +311,7 @@ class XGBoostRegressionModel private[ml] (
     transformSchema(dataset.schema, logging = true)
     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
-    var outputData = transformInternal(dataset)
+    var outputData = PreXGBoost.transformDataFrame(this, dataset)
     var numColsOutput = 0
 
     val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
@@ -425,7 +342,7 @@ class XGBoostRegressionModel private[ml] (
 
 object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
 
-  private val _originalPredictionCol = "_originalPrediction"
+  private[spark] val _originalPredictionCol = "_originalPrediction"
 
   override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader
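Note for reviewers: widening `produceResultIterator`, `producePredictionItrs`, and the `_*Col` constants from `private` to `private[spark]` is what allows `PreXGBoost` (same package) to drive them. The stitching those methods perform boils down to zipping the prediction iterators against the original rows, keeping both sides lazy and aligned. A simplified, self-contained sketch with made-up column values:

```scala
import org.apache.spark.sql.Row

// The booster emits one Row per input row, so a plain zip keeps the original
// columns and the new prediction columns aligned without buffering a batch.
val originalRowItr = Iterator(Row("a", 1), Row("b", 2))
val predictionItr = Iterator(Row(0.9f), Row(0.1f))

val stitched = originalRowItr.zip(predictionItr).map {
  case (original, prediction) => Row.fromSeq(original.toSeq ++ prediction.toSeq)
}
// stitched.toList == List(Row("a", 1, 0.9f), Row("b", 2, 0.1f))
```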