[jvm-packages] fix the prediction issue for multi:softmax (#7694)

This commit is contained in:
Bobby Wang 2022-02-24 01:09:45 +08:00 committed by GitHub
parent 6762c45494
commit 89aa8ddf52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 21 deletions

View File

@ -385,11 +385,37 @@ class XGBoostClassificationModel private[ml](
Vectors.dense(rawPredictions)
}
if ($(rawPredictionCol).nonEmpty) {
outputData = outputData
.withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
numColsOutput += 1
}
if (getObjective.equals("multi:softmax")) {
// For objective=multi:softmax scenario, there is no probability predicted from xgboost.
// Instead, the probability column will be filled with real prediction
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
probability(0)
}
if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
numColsOutput += 1
}
} else {
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
val prob = probability.map(_.toDouble).toArray
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
Vectors.dense(probabilities)
}
if ($(probabilityCol).nonEmpty) {
outputData = outputData
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
numColsOutput += 1
}
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
// From XGBoost probability to MLlib prediction
@ -397,24 +423,12 @@ class XGBoostClassificationModel private[ml](
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
probability2prediction(Vectors.dense(probabilities))
}
if ($(rawPredictionCol).nonEmpty) {
outputData = outputData
.withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
numColsOutput += 1
}
if ($(probabilityCol).nonEmpty) {
outputData = outputData
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
numColsOutput += 1
}
if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
numColsOutput += 1
}
}
if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -17,9 +17,11 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg._
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.apache.spark.Partitioner
class XGBoostClassifierSuite extends FunSuite with PerTest {
@ -102,6 +104,8 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
assert(model.getEta == 0.1)
assert(model.getMaxDepth == 6)
assert(model.numClasses == 6)
val transformedDf = model.transform(trainingDF)
assert(!transformedDf.columns.contains("probability"))
}
test("use base margin") {