[jvm-packages] create dmatrix with specified missing value (#1272)
* create dmatrix with specified missing value * update dmlc-core * support for predict method in spark package repartitioning work around * add more elements to work around training set empty partition issue
This commit is contained in:
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import org.apache.hadoop.fs.{Path, FileSystem}
|
||||
import org.apache.spark.{TaskContext, SparkContext}
|
||||
import org.apache.spark.mllib.linalg.Vector
|
||||
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||
@@ -27,6 +27,7 @@ class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Ser
|
||||
|
||||
/**
|
||||
* Predict result with the given testset (represented as RDD)
|
||||
*
|
||||
* @param testSet test set represented as RDD
|
||||
* @param useExternalCache whether to use external cache for the test set
|
||||
*/
|
||||
@@ -51,6 +52,31 @@ class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Ser
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict result with the given testset (represented as RDD)
|
||||
* @param testSet test set represented as RDD
|
||||
* @param missingValue the specified value to represent the missing value
|
||||
*/
|
||||
def predict(testSet: RDD[DenseVector], missingValue: Float): RDD[Array[Array[Float]]] = {
  // Ship the booster to the executors once instead of capturing it in every task closure.
  val broadcastBooster = testSet.sparkContext.broadcast(_booster)
  testSet.mapPartitions { testSamples =>
    // Materialise the partition into an Array: the row count is needed up front to size
    // the flat buffer, and Array gives constant-time row access (indexing a List per
    // cell made the copy quadratic in the number of rows).
    val sampleArray = testSamples.toArray
    val numRows = sampleArray.length
    if (numRows == 0) {
      // Empty partition (e.g. after repartitioning): emit nothing. This check must come
      // before any access to the first row, otherwise `head` throws on an empty partition.
      Iterator.empty
    } else {
      // The first row determines the column count used for the whole partition.
      val numColumns = sampleArray.head.size
      // Translate to the required format: one row-major Float array for the partition.
      val flatSampleArray = new Array[Float](numRows * numColumns)
      for (row <- 0 until numRows) {
        val rowValues = sampleArray(row).values
        val offset = row * numColumns
        for (col <- 0 until numColumns) {
          flatSampleArray(offset + col) = rowValues(col).toFloat
        }
      }
      // Cells equal to `missingValue` are treated as missing by the DMatrix.
      val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
      // One output element per non-empty partition: the predictions for all its rows.
      Iterator(broadcastBooster.value.predict(dMatrix))
    }
  }
}
|
||||
|
||||
/**
|
||||
* predict result given the test data (represented as DMatrix)
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user