diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index c2e53e9a4..4f3d50f9f 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -16,6 +16,7 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
+import scala.collection.Iterator
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
@@ -229,30 +230,28 @@ class XGBoostClassificationModel private[ml](
     this
   }
 
-  // TODO: Make it public after we resolve performance issue
-  private def margin(features: Vector): Array[Float] = {
-    import DataUtils._
-    val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
-    _booster.predict(data = dm, outPutMargin = true)(0)
-  }
-
-  private def probability(features: Vector): Array[Float] = {
-    import DataUtils._
-    val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
-    _booster.predict(data = dm, outPutMargin = false)(0)
-  }
-
+  /**
+   * Single instance prediction.
+   * Note: The performance is not ideal, use it carefully!
+   */
   override def predict(features: Vector): Double = {
-    throw new Exception("XGBoost-Spark does not support online prediction")
+    import DataUtils._
+    val dm = new DMatrix(Iterator(features.asXGB))
+    val probability = _booster.predict(data = dm)(0)
+    if (numClasses == 2) {
+      math.round(probability(0))
+    } else {
+      Vectors.dense(probability.map(_.toDouble)).argmax
+    }
   }
 
   // Actually we don't use this function at all, to make it pass compiler check.
-  override def predictRaw(features: Vector): Vector = {
+  override protected def predictRaw(features: Vector): Vector = {
     throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
   }
 
   // Actually we don't use this function at all, to make it pass compiler check.
-  override def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
     throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
   }
 
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
index 93c4b7446..29f289102 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
@@ -16,6 +16,7 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
+import scala.collection.Iterator
 import scala.collection.JavaConverters._
 
 import ml.dmlc.xgboost4j.java.Rabit
@@ -225,8 +226,14 @@ class XGBoostRegressionModel private[ml] (
     this
   }
 
+  /**
+   * Single instance prediction.
+   * Note: The performance is not ideal, use it carefully!
+   */
   override def predict(features: Vector): Double = {
-    throw new Exception("XGBoost-Spark does not support online prediction")
+    import DataUtils._
+    val dm = new DMatrix(Iterator(features.asXGB))
+    _booster.predict(data = dm)(0)(0)
   }
 
   private def transformInternal(dataset: Dataset[_]): DataFrame = {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index 4bb9e2c8c..d2814b8a1 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -66,6 +66,13 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
         assert(prediction3(i)(j) === prediction4(i)(j))
       }
     }
+
+    // check the equality of single instance prediction
+    val firstOfDM = testDM.slice(Array(0))
+    val firstOfDF = testDF.head().getAs[Vector]("features")
+    val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
+    val prediction6 = model2.predict(firstOfDF)
+    assert(prediction5 === prediction6)
   }
 
   test("Set params in XGBoost and MLlib way should produce same model") {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
index 86aa96d57..8dba73e61 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
@@ -17,6 +17,7 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
+import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.types._
@@ -49,6 +50,14 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
     assert(prediction1.indices.count { i =>
       math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
     } < prediction1.length * 0.1)
+
+
+    // check the equality of single instance prediction
+    val firstOfDM = testDM.slice(Array(0))
+    val firstOfDF = testDF.head().getAs[Vector]("features")
+    val prediction3 = model1.predict(firstOfDM)(0)(0)
+    val prediction4 = model2.predict(firstOfDF)
+    assert(math.abs(prediction3 - prediction4) <= 0.01f)
   }
 
   test("Set params in XGBoost and MLlib way should produce same model") {
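For reference, the sketch below illustrates how the single-instance predict(features: Vector) made public by this patch might be called end to end. It is only a sketch: the SinglePredictSketch object, the toy training data, the column names, and the training parameters are assumptions for illustration and are not part of the patch; the doc comment's caveat still applies, since every call builds a one-row DMatrix.

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

object SinglePredictSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("single-instance-prediction-sketch")
      .getOrCreate()
    import spark.implicits._

    // Toy training data: (features, label) rows, purely illustrative.
    val trainingDF = Seq(
      (Vectors.dense(1.0, 2.0), 0.0),
      (Vectors.dense(3.0, 4.0), 1.0),
      (Vectors.dense(5.0, 6.0), 1.0),
      (Vectors.dense(0.5, 1.5), 0.0)
    ).toDF("features", "label")

    // Illustrative parameters, not taken from the patch.
    val xgbParam = Map(
      "eta" -> 0.1,
      "max_depth" -> 3,
      "objective" -> "binary:logistic",
      "num_round" -> 5,
      "num_workers" -> 1)

    val model = new XGBoostClassifier(xgbParam)
      .setFeaturesCol("features")
      .setLabelCol("label")
      .fit(trainingDF)

    // The method exposed by this patch: one Vector in, one predicted label out.
    // Each call constructs a one-row DMatrix, so keep it out of hot loops and
    // prefer model.transform(df) for bulk scoring.
    val label: Double = model.predict(Vectors.dense(2.0, 3.0))
    println(s"predicted label: $label")

    spark.stop()
  }
}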