[jvm-packages] fix the prediction issue for multi:softmax (#7694)

This commit is contained in:
Bobby Wang
2022-02-24 01:09:45 +08:00
committed by GitHub
parent 6762c45494
commit 89aa8ddf52
2 changed files with 39 additions and 21 deletions

View File

@@ -385,18 +385,7 @@ class XGBoostClassificationModel private[ml](
Vectors.dense(rawPredictions)
}
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
val prob = probability.map(_.toDouble).toArray
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
Vectors.dense(probabilities)
}
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
// From XGBoost probability to MLlib prediction
val prob = probability.map(_.toDouble).toArray
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
probability2prediction(Vectors.dense(probabilities))
}
if ($(rawPredictionCol).nonEmpty) {
outputData = outputData
@@ -404,16 +393,41 @@ class XGBoostClassificationModel private[ml](
numColsOutput += 1
}
if ($(probabilityCol).nonEmpty) {
outputData = outputData
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
numColsOutput += 1
}
if (getObjective.equals("multi:softmax")) {
// For objective=multi:softmax scenario, there is no probability predicted from xgboost.
// Instead, the probability column will be filled with real prediction
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
probability(0)
}
if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
numColsOutput += 1
}
if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
numColsOutput += 1
} else {
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
val prob = probability.map(_.toDouble).toArray
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
Vectors.dense(probabilities)
}
if ($(probabilityCol).nonEmpty) {
outputData = outputData
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
numColsOutput += 1
}
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
// From XGBoost probability to MLlib prediction
val prob = probability.map(_.toDouble).toArray
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
probability2prediction(Vectors.dense(probabilities))
}
if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
numColsOutput += 1
}
}
if (numColsOutput == 0) {