diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index a0fa7e943..6956db215 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -59,3 +59,4 @@ List of Contributors * [Sam Thomson](https://github.com/sammthomson) * [ganesh-krishnan](https://github.com/ganesh-krishnan) * [Damien Carol](https://github.com/damiencarol) +* [Alex Bain](https://github.com/convexquad) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index ea8e8a5e5..9aa5e84dc 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -26,9 +26,9 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, Booster} class XGBoostModel(_booster: Booster) extends Serializable { /** - * Predict result with the given testset (represented as RDD) + * Predict result with the given test set (represented as RDD) * - * @param testSet test set representd as RDD + * @param testSet test set represented as RDD * @param useExternalCache whether to use external cache for the test set */ def predict(testSet: RDD[Vector], useExternalCache: Boolean = false): RDD[Array[Array[Float]]] = { @@ -53,8 +53,9 @@ class XGBoostModel(_booster: Booster) extends Serializable { } /** - * Predict result with the given testset (represented as RDD) - * @param testSet test set representd as RDD + * Predict result with the given test set (represented as RDD) + * + * @param testSet test set represented as RDD * @param missingValue the specified value to represent the missing value */ def predict(testSet: RDD[DenseVector], missingValue: Float): RDD[Array[Array[Float]]] = { @@ -78,12 +79,41 @@ class XGBoostModel(_booster: Booster) extends Serializable { } /** - * predict result given the test data (represented as DMatrix) + * Predict result with the given test set (represented as DMatrix) + * + * @param testSet test set represented as DMatrix */ def predict(testSet: DMatrix): Array[Array[Float]] = { _booster.predict(testSet, true, 0) } + /** + * Predict leaf instances with the given test set (represented as RDD) + * + * @param testSet test set represented as RDD + */ + def predictLeaves(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = { + import DataUtils._ + val broadcastBooster = testSet.sparkContext.broadcast(_booster) + testSet.mapPartitions { testSamples => + if (testSamples.hasNext) { + val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) + Iterator(broadcastBooster.value.predictLeaf(dMatrix, 0)) + } else { + Iterator() + } + } + } + + /** + * Predict leaf instances with the given test set (represented as DMatrix) + * + * @param testSet test set represented as DMatrix + */ + def predictLeaves(testSet: DMatrix): Array[Array[Float]] = { + _booster.predictLeaf(testSet, 0) + } + /** * Save the model as to HDFS-compatible file system. * @@ -97,7 +127,7 @@ class XGBoostModel(_booster: Booster) extends Serializable { } /** - * get the booster instance of this model + * Get the booster instance of this model */ def booster: Booster = _booster }