Fixed a signature of XGBoostModel.predict (#2476)

Prior to this commit, XGBoostModel.predict produced an RDD containing
one array of predictions per partition, effectively changing the shape
with respect to the input RDD. A more natural contract for a prediction
API is that, given an RDD, it returns a new RDD with the same number
of elements. This allows users to easily match inputs with
predictions.

This commit removes one layer of nesting in XGBoostModel.predict output.
Even though the change is clearly non-backward compatible, I still
think it is well justified. See discussion in 06bd5dca for motivation.
This commit is contained in:
Sergei Lebedev 2017-07-03 06:42:46 +02:00 committed by Nan Zhu
parent ed8bc4521e
commit 8ceeb32bad
2 changed files with 14 additions and 18 deletions

View File

@@ -59,19 +59,18 @@ abstract class XGBoostModel(protected var _booster: Booster)
* *
* @param testSet test set represented as RDD * @param testSet test set represented as RDD
*/ */
def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Array[Float]]] = { def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Float]] = {
import DataUtils._ import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(_booster) val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples => testSet.mapPartitions { testSamples =>
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava) Rabit.init(rabitEnv.asJava)
if (testSamples.hasNext) { if (testSamples.nonEmpty) {
val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
try { try {
val res = broadcastBooster.value.predictLeaf(dMatrix) broadcastBooster.value.predictLeaf(dMatrix).iterator
Rabit.shutdown()
Iterator(res)
} finally { } finally {
Rabit.shutdown()
dMatrix.delete() dMatrix.delete()
} }
} else { } else {
@@ -151,7 +150,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
* @param testSet test set represented as RDD * @param testSet test set represented as RDD
* @param missingValue the specified value to represent the missing value * @param missingValue the specified value to represent the missing value
*/ */
def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Array[Float]]] = { def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = {
val broadcastBooster = testSet.sparkContext.broadcast(_booster) val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples => testSet.mapPartitions { testSamples =>
val sampleArray = testSamples.toList val sampleArray = testSamples.toList
@@ -169,7 +168,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
} }
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue) val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
try { try {
Iterator(broadcastBooster.value.predict(dMatrix)) broadcastBooster.value.predict(dMatrix).iterator
} finally { } finally {
Rabit.shutdown() Rabit.shutdown()
dMatrix.delete() dMatrix.delete()
@@ -188,7 +187,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
def predict( def predict(
testSet: RDD[MLVector], testSet: RDD[MLVector],
useExternalCache: Boolean = false, useExternalCache: Boolean = false,
outputMargin: Boolean = false): RDD[Array[Array[Float]]] = { outputMargin: Boolean = false): RDD[Array[Float]] = {
val broadcastBooster = testSet.sparkContext.broadcast(_booster) val broadcastBooster = testSet.sparkContext.broadcast(_booster)
val appName = testSet.context.appName val appName = testSet.context.appName
testSet.mapPartitions { testSamples => testSet.mapPartitions { testSamples =>
@@ -205,7 +204,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
} }
val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName)) val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
try { try {
Iterator(broadcastBooster.value.predict(dMatrix)) broadcastBooster.value.predict(dMatrix).iterator
} finally { } finally {
Rabit.shutdown() Rabit.shutdown()
dMatrix.delete() dMatrix.delete()

View File

@@ -252,7 +252,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
"objective" -> "binary:logistic") "objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val predRDD = xgBoostModel.predict(testRDD) val predRDD = xgBoostModel.predict(testRDD)
val predResult1 = predRDD.collect()(0) val predResult1 = predRDD.collect()
assert(testRDD.count() === predResult1.length) assert(testRDD.count() === predResult1.length)
import DataUtils._ import DataUtils._
val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator)) val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
@@ -273,14 +273,11 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
test("test prediction functionality with empty partition") { test("test prediction functionality with empty partition") {
def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = { def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
val sampleList = new ListBuffer[SparkVector] sparkContext.getOrElse(sc).parallelize(List[SparkVector](), numWorkers)
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
} }
val trainingRDD = buildTrainingRDD(sc) val trainingRDD = buildTrainingRDD(sc)
val testRDD = buildEmptyRDD() val testRDD = buildEmptyRDD()
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1", val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap "objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
@@ -358,7 +355,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1) val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1)
val predRDD = xgBoostModel.predict(testRDD) val predRDD = xgBoostModel.predict(testRDD)
val predResult1: Array[Array[Float]] = predRDD.collect()(0) val predResult1: Array[Array[Float]] = predRDD.collect()
assert(testRDD.count() === predResult1.length) assert(testRDD.count() === predResult1.length)
val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData) val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData)
@@ -386,7 +383,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2) val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2)
val predRDD = xgBoostModel.predict(testRDD) val predRDD = xgBoostModel.predict(testRDD)
val predResult1: Array[Array[Float]] = predRDD.collect()(0) val predResult1: Array[Array[Float]] = predRDD.collect()
assert(testRDD.count() === predResult1.length) assert(testRDD.count() === predResult1.length)
} }
@@ -403,7 +400,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val trainMargin = { val trainMargin = {
XGBoost.trainWithRDD(trainRDD, paramMap, round = 1, nWorkers = 2) XGBoost.trainWithRDD(trainRDD, paramMap, round = 1, nWorkers = 2)
.predict(trainRDD.map(_.features), outputMargin = true) .predict(trainRDD.map(_.features), outputMargin = true)
.flatMap { _.flatten.iterator } .map { case Array(m) => m }
} }
val xgBoostModel = XGBoost.trainWithRDD( val xgBoostModel = XGBoost.trainWithRDD(
@@ -413,6 +410,6 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
nWorkers = 2, nWorkers = 2,
baseMargin = trainMargin) baseMargin = trainMargin)
assert(testRDD.count() === xgBoostModel.predict(testRDD).first().length) assert(testRDD.count() === xgBoostModel.predict(testRDD).count())
} }
} }