diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index ec76c9177..9773e3447 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -251,11 +251,11 @@ class XGBoostClassificationModel private[ml]( override def predict(features: Vector): Double = { import DataUtils._ val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing))) - val probability = _booster.predict(data = dm)(0) + val probability = _booster.predict(data = dm)(0).map(_.toDouble) if (numClasses == 2) { math.round(probability(0)) } else { - Vectors.dense(probability.map(_.toDouble)).argmax + probability2prediction(Vectors.dense(probability)) } }