[JVM-packages] Support single instance prediction. (#3464)
* Support single instance prediction. * Address comments.
This commit is contained in:
@@ -66,6 +66,13 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
assert(prediction3(i)(j) === prediction4(i)(j))
|
||||
}
|
||||
}
|
||||
|
||||
// check the equality of single instance prediction
|
||||
val firstOfDM = testDM.slice(Array(0))
|
||||
val firstOfDF = testDF.head().getAs[Vector]("features")
|
||||
val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
|
||||
val prediction6 = model2.predict(firstOfDF)
|
||||
assert(prediction5 === prediction6)
|
||||
}
|
||||
|
||||
test("Set params in XGBoost and MLlib way should produce same model") {
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.types._
|
||||
@@ -49,6 +50,14 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
||||
assert(prediction1.indices.count { i =>
|
||||
math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
|
||||
} < prediction1.length * 0.1)
|
||||
|
||||
|
||||
// check the equality of single instance prediction
|
||||
val firstOfDM = testDM.slice(Array(0))
|
||||
val firstOfDF = testDF.head().getAs[Vector]("features")
|
||||
val prediction3 = model1.predict(firstOfDM)(0)(0)
|
||||
val prediction4 = model2.predict(firstOfDF)
|
||||
assert(math.abs(prediction3 - prediction4) <= 0.01f)
|
||||
}
|
||||
|
||||
test("Set params in XGBoost and MLlib way should produce same model") {
|
||||
|
||||
Reference in New Issue
Block a user