Expose predictLeaf functionality in Scala XGBoostModel (#1351)

This commit is contained in:
convexquad 2016-07-12 03:55:24 -07:00 committed by Nan Zhu
parent 75d9be55de
commit 313764b3be
2 changed files with 37 additions and 6 deletions

View File

@ -59,3 +59,4 @@ List of Contributors
* [Sam Thomson](https://github.com/sammthomson) * [Sam Thomson](https://github.com/sammthomson)
* [ganesh-krishnan](https://github.com/ganesh-krishnan) * [ganesh-krishnan](https://github.com/ganesh-krishnan)
* [Damien Carol](https://github.com/damiencarol) * [Damien Carol](https://github.com/damiencarol)
* [Alex Bain](https://github.com/convexquad)

View File

@ -28,7 +28,7 @@ class XGBoostModel(_booster: Booster) extends Serializable {
/** /**
* Predict result with the given test set (represented as RDD) * Predict result with the given test set (represented as RDD)
* *
* @param testSet test set representd as RDD * @param testSet test set represented as RDD
* @param useExternalCache whether to use external cache for the test set * @param useExternalCache whether to use external cache for the test set
*/ */
def predict(testSet: RDD[Vector], useExternalCache: Boolean = false): RDD[Array[Array[Float]]] = { def predict(testSet: RDD[Vector], useExternalCache: Boolean = false): RDD[Array[Array[Float]]] = {
@ -54,7 +54,8 @@ class XGBoostModel(_booster: Booster) extends Serializable {
/** /**
* Predict result with the given test set (represented as RDD) * Predict result with the given test set (represented as RDD)
* @param testSet test set representd as RDD *
* @param testSet test set represented as RDD
* @param missingValue the specified value to represent the missing value * @param missingValue the specified value to represent the missing value
*/ */
def predict(testSet: RDD[DenseVector], missingValue: Float): RDD[Array[Array[Float]]] = { def predict(testSet: RDD[DenseVector], missingValue: Float): RDD[Array[Array[Float]]] = {
@ -78,12 +79,41 @@ class XGBoostModel(_booster: Booster) extends Serializable {
} }
/** /**
* predict result given the test data (represented as DMatrix) * Predict result with the given test set (represented as DMatrix)
*
* @param testSet test set represented as DMatrix
*/ */
def predict(testSet: DMatrix): Array[Array[Float]] = { def predict(testSet: DMatrix): Array[Array[Float]] = {
_booster.predict(testSet, true, 0) _booster.predict(testSet, true, 0)
} }
/**
* Predict leaf instances with the given test set (represented as RDD)
*
* @param testSet test set represented as RDD
*/
def predictLeaves(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
if (testSamples.hasNext) {
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
Iterator(broadcastBooster.value.predictLeaf(dMatrix, 0))
} else {
Iterator()
}
}
}
/**
* Predict leaf instances with the given test set (represented as DMatrix)
*
* @param testSet test set represented as DMatrix
*/
def predictLeaves(testSet: DMatrix): Array[Array[Float]] = {
_booster.predictLeaf(testSet, 0)
}
/** /**
* Save the model as to HDFS-compatible file system. * Save the model as to HDFS-compatible file system.
* *
@ -97,7 +127,7 @@ class XGBoostModel(_booster: Booster) extends Serializable {
} }
/** /**
* get the booster instance of this model * Get the booster instance of this model
*/ */
def booster: Booster = _booster def booster: Booster = _booster
} }