[JVM-packages] Support single instance prediction. (#3464)
* Support single instance prediction. * Address comments.
This commit is contained in:
parent
2200939416
commit
2f8764955c
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import scala.collection.Iterator
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
|
||||||
@ -229,30 +230,28 @@ class XGBoostClassificationModel private[ml](
|
|||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Make it public after we resolve performance issue
|
/**
|
||||||
private def margin(features: Vector): Array[Float] = {
|
* Single instance prediction.
|
||||||
import DataUtils._
|
* Note: The performance is not ideal, use it carefully!
|
||||||
val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
|
*/
|
||||||
_booster.predict(data = dm, outPutMargin = true)(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def probability(features: Vector): Array[Float] = {
|
|
||||||
import DataUtils._
|
|
||||||
val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
|
|
||||||
_booster.predict(data = dm, outPutMargin = false)(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override def predict(features: Vector): Double = {
|
override def predict(features: Vector): Double = {
|
||||||
throw new Exception("XGBoost-Spark does not support online prediction")
|
import DataUtils._
|
||||||
|
val dm = new DMatrix(Iterator(features.asXGB))
|
||||||
|
val probability = _booster.predict(data = dm)(0)
|
||||||
|
if (numClasses == 2) {
|
||||||
|
math.round(probability(0))
|
||||||
|
} else {
|
||||||
|
Vectors.dense(probability.map(_.toDouble)).argmax
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Actually we don't use this function at all, to make it pass compiler check.
|
// Actually we don't use this function at all, to make it pass compiler check.
|
||||||
override def predictRaw(features: Vector): Vector = {
|
override protected def predictRaw(features: Vector): Vector = {
|
||||||
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
|
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Actually we don't use this function at all, to make it pass compiler check.
|
// Actually we don't use this function at all, to make it pass compiler check.
|
||||||
override def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
||||||
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
|
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import scala.collection.Iterator
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.Rabit
|
import ml.dmlc.xgboost4j.java.Rabit
|
||||||
@ -225,8 +226,14 @@ class XGBoostRegressionModel private[ml] (
|
|||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Single instance prediction.
|
||||||
|
* Note: The performance is not ideal, use it carefully!
|
||||||
|
*/
|
||||||
override def predict(features: Vector): Double = {
|
override def predict(features: Vector): Double = {
|
||||||
throw new Exception("XGBoost-Spark does not support online prediction")
|
import DataUtils._
|
||||||
|
val dm = new DMatrix(Iterator(features.asXGB))
|
||||||
|
_booster.predict(data = dm)(0)(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def transformInternal(dataset: Dataset[_]): DataFrame = {
|
private def transformInternal(dataset: Dataset[_]): DataFrame = {
|
||||||
|
|||||||
@ -66,6 +66,13 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
assert(prediction3(i)(j) === prediction4(i)(j))
|
assert(prediction3(i)(j) === prediction4(i)(j))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check the equality of single instance prediction
|
||||||
|
val firstOfDM = testDM.slice(Array(0))
|
||||||
|
val firstOfDF = testDF.head().getAs[Vector]("features")
|
||||||
|
val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
|
||||||
|
val prediction6 = model2.predict(firstOfDF)
|
||||||
|
assert(prediction5 === prediction6)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Set params in XGBoost and MLlib way should produce same model") {
|
test("Set params in XGBoost and MLlib way should produce same model") {
|
||||||
|
|||||||
@ -17,6 +17,7 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||||
|
import org.apache.spark.ml.linalg.Vector
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.Row
|
import org.apache.spark.sql.Row
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
@ -49,6 +50,14 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
assert(prediction1.indices.count { i =>
|
assert(prediction1.indices.count { i =>
|
||||||
math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
|
math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
|
||||||
} < prediction1.length * 0.1)
|
} < prediction1.length * 0.1)
|
||||||
|
|
||||||
|
|
||||||
|
// check the equality of single instance prediction
|
||||||
|
val firstOfDM = testDM.slice(Array(0))
|
||||||
|
val firstOfDF = testDF.head().getAs[Vector]("features")
|
||||||
|
val prediction3 = model1.predict(firstOfDM)(0)(0)
|
||||||
|
val prediction4 = model2.predict(firstOfDF)
|
||||||
|
assert(math.abs(prediction3 - prediction4) <= 0.01f)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Set params in XGBoost and MLlib way should produce same model") {
|
test("Set params in XGBoost and MLlib way should produce same model") {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user