diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index 6b48a50f8..b35f5da34 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -415,7 +415,9 @@ class XGBoostClassificationModel private[ml](
     var numColsOutput = 0

     val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
-      Vectors.dense(rawPrediction.map(_.toDouble).toArray)
+      val raw = rawPrediction.map(_.toDouble).toArray
+      val rawPredictions = if (numClasses == 2) Array(-raw(0), raw(0)) else raw
+      Vectors.dense(rawPredictions)
     }

     val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index 8b7bdb448..7645d2c89 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -60,10 +60,11 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
       collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
     assert(testDF.count() === prediction4.size)

+    // the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
     for (i <- prediction3.indices) {
-      assert(prediction3(i).length === prediction4(i).values.length)
+      assert(prediction3(i).length === prediction4(i).values.length - 1)
       for (j <- prediction3(i).indices) {
-        assert(prediction3(i)(j) === prediction4(i)(j))
+        assert(prediction3(i)(j) === prediction4(i)(j + 1))
       }
     }
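
For context, here is a minimal standalone sketch of the transformation this patch applies: for binary models XGBoost emits a single margin per row, while Spark's evaluators expect one rawPrediction entry per class, so class 0 receives the negated margin. The `RawPredictionSketch` object, the `toRawPrediction` helper, and the sample margins below are illustrative only and are not part of the patch.

```scala
// Minimal sketch (not the library's exact UDF) of how a single-margin binary
// output is expanded into the two-class rawPrediction layout Spark expects.
object RawPredictionSketch {
  def toRawPrediction(margins: Array[Float], numClasses: Int): Array[Double] = {
    val raw = margins.map(_.toDouble)
    // Binary case: class 0 gets the negated margin, class 1 the margin itself.
    if (numClasses == 2) Array(-raw(0), raw(0)) else raw
  }

  def main(args: Array[String]): Unit = {
    // A positive margin of 1.5f becomes [-1.5, 1.5], so argmax picks class 1.
    println(toRawPrediction(Array(1.5f), numClasses = 2).mkString("[", ", ", "]"))
    // Multi-class margins pass through unchanged.
    println(toRawPrediction(Array(0.1f, 0.7f, 0.2f), numClasses = 3).mkString("[", ", ", "]"))
  }
}
```

This also explains the test adjustment above: the new rawPrediction vector has one extra leading element for binary classification, so the original margins now live at index `j + 1`.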