[jvm-packages] fix the prediction issue for multi:softmax (#7694)
This commit is contained in:
@@ -385,18 +385,7 @@ class XGBoostClassificationModel private[ml](
|
||||
Vectors.dense(rawPredictions)
|
||||
}
|
||||
|
||||
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||
val prob = probability.map(_.toDouble).toArray
|
||||
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
||||
Vectors.dense(probabilities)
|
||||
}
|
||||
|
||||
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||
// From XGBoost probability to MLlib prediction
|
||||
val prob = probability.map(_.toDouble).toArray
|
||||
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
||||
probability2prediction(Vectors.dense(probabilities))
|
||||
}
|
||||
|
||||
if ($(rawPredictionCol).nonEmpty) {
|
||||
outputData = outputData
|
||||
@@ -404,16 +393,41 @@ class XGBoostClassificationModel private[ml](
|
||||
numColsOutput += 1
|
||||
}
|
||||
|
||||
if ($(probabilityCol).nonEmpty) {
|
||||
outputData = outputData
|
||||
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
|
||||
numColsOutput += 1
|
||||
}
|
||||
if (getObjective.equals("multi:softmax")) {
|
||||
// For objective=multi:softmax scenario, there is no probability predicted from xgboost.
|
||||
// Instead, the probability column will be filled with real prediction
|
||||
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||
probability(0)
|
||||
}
|
||||
if ($(predictionCol).nonEmpty) {
|
||||
outputData = outputData
|
||||
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
|
||||
numColsOutput += 1
|
||||
}
|
||||
|
||||
if ($(predictionCol).nonEmpty) {
|
||||
outputData = outputData
|
||||
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
|
||||
numColsOutput += 1
|
||||
} else {
|
||||
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||
val prob = probability.map(_.toDouble).toArray
|
||||
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
||||
Vectors.dense(probabilities)
|
||||
}
|
||||
if ($(probabilityCol).nonEmpty) {
|
||||
outputData = outputData
|
||||
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
|
||||
numColsOutput += 1
|
||||
}
|
||||
|
||||
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||
// From XGBoost probability to MLlib prediction
|
||||
val prob = probability.map(_.toDouble).toArray
|
||||
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
||||
probability2prediction(Vectors.dense(probabilities))
|
||||
}
|
||||
if ($(predictionCol).nonEmpty) {
|
||||
outputData = outputData
|
||||
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
|
||||
numColsOutput += 1
|
||||
}
|
||||
}
|
||||
|
||||
if (numColsOutput == 0) {
|
||||
|
||||
Reference in New Issue
Block a user