Expose predictLeaf functionality in Scala XGBoostModel (#1351)
This commit is contained in:
parent
75d9be55de
commit
313764b3be
@ -59,3 +59,4 @@ List of Contributors
|
|||||||
* [Sam Thomson](https://github.com/sammthomson)
|
* [Sam Thomson](https://github.com/sammthomson)
|
||||||
* [ganesh-krishnan](https://github.com/ganesh-krishnan)
|
* [ganesh-krishnan](https://github.com/ganesh-krishnan)
|
||||||
* [Damien Carol](https://github.com/damiencarol)
|
* [Damien Carol](https://github.com/damiencarol)
|
||||||
|
* [Alex Bain](https://github.com/convexquad)
|
||||||
|
|||||||
@ -26,9 +26,9 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
|||||||
class XGBoostModel(_booster: Booster) extends Serializable {
|
class XGBoostModel(_booster: Booster) extends Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict result with the given testset (represented as RDD)
|
* Predict result with the given test set (represented as RDD)
|
||||||
*
|
*
|
||||||
* @param testSet test set representd as RDD
|
* @param testSet test set represented as RDD
|
||||||
* @param useExternalCache whether to use external cache for the test set
|
* @param useExternalCache whether to use external cache for the test set
|
||||||
*/
|
*/
|
||||||
def predict(testSet: RDD[Vector], useExternalCache: Boolean = false): RDD[Array[Array[Float]]] = {
|
def predict(testSet: RDD[Vector], useExternalCache: Boolean = false): RDD[Array[Array[Float]]] = {
|
||||||
@ -53,8 +53,9 @@ class XGBoostModel(_booster: Booster) extends Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict result with the given testset (represented as RDD)
|
* Predict result with the given test set (represented as RDD)
|
||||||
* @param testSet test set representd as RDD
|
*
|
||||||
|
* @param testSet test set represented as RDD
|
||||||
* @param missingValue the specified value to represent the missing value
|
* @param missingValue the specified value to represent the missing value
|
||||||
*/
|
*/
|
||||||
def predict(testSet: RDD[DenseVector], missingValue: Float): RDD[Array[Array[Float]]] = {
|
def predict(testSet: RDD[DenseVector], missingValue: Float): RDD[Array[Array[Float]]] = {
|
||||||
@ -78,12 +79,41 @@ class XGBoostModel(_booster: Booster) extends Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* predict result given the test data (represented as DMatrix)
|
* Predict result with the given test set (represented as DMatrix)
|
||||||
|
*
|
||||||
|
* @param testSet test set represented as DMatrix
|
||||||
*/
|
*/
|
||||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
||||||
_booster.predict(testSet, true, 0)
|
_booster.predict(testSet, true, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Predict the leaf index in each tree for every instance of the given test set (represented as RDD)
|
||||||
|
*
|
||||||
|
* @param testSet test set represented as RDD
|
||||||
|
*/
|
||||||
|
def predictLeaves(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
|
||||||
|
import DataUtils._
|
||||||
|
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||||
|
testSet.mapPartitions { testSamples =>
|
||||||
|
if (testSamples.hasNext) {
|
||||||
|
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
||||||
|
Iterator(broadcastBooster.value.predictLeaf(dMatrix, 0))
|
||||||
|
} else {
|
||||||
|
Iterator()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Predict the leaf index in each tree for every instance of the given test set (represented as DMatrix)
|
||||||
|
*
|
||||||
|
* @param testSet test set represented as DMatrix
|
||||||
|
*/
|
||||||
|
def predictLeaves(testSet: DMatrix): Array[Array[Float]] = {
|
||||||
|
_booster.predictLeaf(testSet, 0)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Save the model to an HDFS-compatible file system.
|
* Save the model to an HDFS-compatible file system.
|
||||||
*
|
*
|
||||||
@ -97,7 +127,7 @@ class XGBoostModel(_booster: Booster) extends Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get the booster instance of this model
|
* Get the booster instance of this model
|
||||||
*/
|
*/
|
||||||
def booster: Booster = _booster
|
def booster: Booster = _booster
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user