diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index 4d03b309c..c8635d93c 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -385,18 +385,7 @@ class XGBoostClassificationModel private[ml](
       Vectors.dense(rawPredictions)
     }
 
-    val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
-      val prob = probability.map(_.toDouble).toArray
-      val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
-      Vectors.dense(probabilities)
-    }
 
-    val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
-      // From XGBoost probability to MLlib prediction
-      val prob = probability.map(_.toDouble).toArray
-      val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
-      probability2prediction(Vectors.dense(probabilities))
-    }
 
     if ($(rawPredictionCol).nonEmpty) {
       outputData = outputData
@@ -404,16 +393,41 @@ class XGBoostClassificationModel private[ml](
       numColsOutput += 1
     }
 
-    if ($(probabilityCol).nonEmpty) {
-      outputData = outputData
-        .withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
-      numColsOutput += 1
-    }
+    if (getObjective.equals("multi:softmax")) {
+      // For objective=multi:softmax scenario, there is no probability predicted from xgboost.
+      // Instead, the probability column will be filled with real prediction
+      val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
+        probability(0)
+      }
+      if ($(predictionCol).nonEmpty) {
+        outputData = outputData
+          .withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
+        numColsOutput += 1
+      }
 
-    if ($(predictionCol).nonEmpty) {
-      outputData = outputData
-        .withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
-      numColsOutput += 1
+    } else {
+      val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
+        val prob = probability.map(_.toDouble).toArray
+        val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
+        Vectors.dense(probabilities)
+      }
+      if ($(probabilityCol).nonEmpty) {
+        outputData = outputData
+          .withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
+        numColsOutput += 1
+      }
+
+      val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
+        // From XGBoost probability to MLlib prediction
+        val prob = probability.map(_.toDouble).toArray
+        val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
+        probability2prediction(Vectors.dense(probabilities))
+      }
+      if ($(predictionCol).nonEmpty) {
+        outputData = outputData
+          .withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
+        numColsOutput += 1
+      }
     }
 
     if (numColsOutput == 0) {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index c0c2988f6..7940a51e5 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -17,9 +17,11 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
+
 import org.apache.spark.ml.linalg._
 import org.apache.spark.sql._
 import org.scalatest.FunSuite
+
 import org.apache.spark.Partitioner
 
 class XGBoostClassifierSuite extends FunSuite with PerTest {
@@ -102,6 +104,8 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     assert(model.getEta == 0.1)
     assert(model.getMaxDepth == 6)
     assert(model.numClasses == 6)
+    val transformedDf = model.transform(trainingDF)
+    assert(!transformedDf.columns.contains("probability"))
   }
 
   test("use base margin") {
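
Note (not part of the patch): below is a minimal usage sketch of the behavior this diff introduces, written against the public xgboost4j-spark API. The object name, the toy three-class dataset, and the parameter values are illustrative assumptions, not taken from the patch; the final assertion mirrors the new test case. With objective=multi:softmax, XGBoost returns class indices rather than class probabilities, so the patched transform() fills predictionCol directly and produces no probability column.

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object MultiSoftmaxExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("multi-softmax-sketch")
      .getOrCreate()
    import spark.implicits._

    // Tiny illustrative three-class dataset; real training data would be far larger.
    val trainingDF = Seq(
      (0.0, Vectors.dense(0.1, 0.2)),
      (1.0, Vectors.dense(0.9, 0.3)),
      (2.0, Vectors.dense(0.4, 0.8))
    ).toDF("label", "features")

    val classifier = new XGBoostClassifier(Map(
      "objective" -> "multi:softmax", // softmax yields class indices, not probabilities
      "num_class" -> 3,
      "num_round" -> 5,
      "num_workers" -> 1
    )).setFeaturesCol("features").setLabelCol("label")

    val model = classifier.fit(trainingDF)
    val transformed = model.transform(trainingDF)

    // After this patch, multi:softmax output carries predictionCol but no probability column.
    assert(!transformed.columns.contains("probability"))
    transformed.select("prediction").show()

    spark.stop()
  }
}

For any other classification objective (for example multi:softprob or binary:logistic), the else branch in the patch keeps the previous behavior and the probability column is still produced.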