[JVM-packages] Support single instance prediction. (#3464)
* Support single instance prediction. * Address comments.
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.Iterator
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
@@ -229,30 +230,28 @@ class XGBoostClassificationModel private[ml](
|
||||
this
|
||||
}
|
||||
|
||||
// TODO: Make it public after we resolve performance issue
|
||||
private def margin(features: Vector): Array[Float] = {
|
||||
import DataUtils._
|
||||
val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
|
||||
_booster.predict(data = dm, outPutMargin = true)(0)
|
||||
}
|
||||
|
||||
private def probability(features: Vector): Array[Float] = {
|
||||
import DataUtils._
|
||||
val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
|
||||
_booster.predict(data = dm, outPutMargin = false)(0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Single instance prediction.
|
||||
* Note: The performance is not ideal, use it carefully!
|
||||
*/
|
||||
override def predict(features: Vector): Double = {
|
||||
throw new Exception("XGBoost-Spark does not support online prediction")
|
||||
import DataUtils._
|
||||
val dm = new DMatrix(Iterator(features.asXGB))
|
||||
val probability = _booster.predict(data = dm)(0)
|
||||
if (numClasses == 2) {
|
||||
math.round(probability(0))
|
||||
} else {
|
||||
Vectors.dense(probability.map(_.toDouble)).argmax
|
||||
}
|
||||
}
|
||||
|
||||
// Actually we don't use this function at all, to make it pass compiler check.
|
||||
override def predictRaw(features: Vector): Vector = {
|
||||
override protected def predictRaw(features: Vector): Vector = {
|
||||
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
|
||||
}
|
||||
|
||||
// Actually we don't use this function at all, to make it pass compiler check.
|
||||
override def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
||||
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
||||
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.Iterator
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
@@ -225,8 +226,14 @@ class XGBoostRegressionModel private[ml] (
|
||||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Single instance prediction.
|
||||
* Note: The performance is not ideal, use it carefully!
|
||||
*/
|
||||
override def predict(features: Vector): Double = {
|
||||
throw new Exception("XGBoost-Spark does not support online prediction")
|
||||
import DataUtils._
|
||||
val dm = new DMatrix(Iterator(features.asXGB))
|
||||
_booster.predict(data = dm)(0)(0)
|
||||
}
|
||||
|
||||
private def transformInternal(dataset: Dataset[_]): DataFrame = {
|
||||
|
||||
Reference in New Issue
Block a user