[jvm-packages] fix the prediction issue for multi:softmax (#7694)
This commit is contained in:
parent
6762c45494
commit
89aa8ddf52
@ -385,11 +385,37 @@ class XGBoostClassificationModel private[ml](
|
|||||||
Vectors.dense(rawPredictions)
|
Vectors.dense(rawPredictions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if ($(rawPredictionCol).nonEmpty) {
|
||||||
|
outputData = outputData
|
||||||
|
.withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
|
||||||
|
numColsOutput += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getObjective.equals("multi:softmax")) {
|
||||||
|
// For objective=multi:softmax, XGBoost does not predict per-class probabilities.
|
||||||
|
// Instead, the internal probability column holds the actual predicted class label.
|
||||||
|
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||||
|
probability(0)
|
||||||
|
}
|
||||||
|
if ($(predictionCol).nonEmpty) {
|
||||||
|
outputData = outputData
|
||||||
|
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
|
||||||
|
numColsOutput += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
|
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||||
val prob = probability.map(_.toDouble).toArray
|
val prob = probability.map(_.toDouble).toArray
|
||||||
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
||||||
Vectors.dense(probabilities)
|
Vectors.dense(probabilities)
|
||||||
}
|
}
|
||||||
|
if ($(probabilityCol).nonEmpty) {
|
||||||
|
outputData = outputData
|
||||||
|
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
|
||||||
|
numColsOutput += 1
|
||||||
|
}
|
||||||
|
|
||||||
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
|
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||||
// From XGBoost probability to MLlib prediction
|
// From XGBoost probability to MLlib prediction
|
||||||
@ -397,24 +423,12 @@ class XGBoostClassificationModel private[ml](
|
|||||||
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
|
||||||
probability2prediction(Vectors.dense(probabilities))
|
probability2prediction(Vectors.dense(probabilities))
|
||||||
}
|
}
|
||||||
|
|
||||||
if ($(rawPredictionCol).nonEmpty) {
|
|
||||||
outputData = outputData
|
|
||||||
.withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
|
|
||||||
numColsOutput += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if ($(probabilityCol).nonEmpty) {
|
|
||||||
outputData = outputData
|
|
||||||
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
|
|
||||||
numColsOutput += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if ($(predictionCol).nonEmpty) {
|
if ($(predictionCol).nonEmpty) {
|
||||||
outputData = outputData
|
outputData = outputData
|
||||||
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
|
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
|
||||||
numColsOutput += 1
|
numColsOutput += 1
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (numColsOutput == 0) {
|
if (numColsOutput == 0) {
|
||||||
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
|
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -17,9 +17,11 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||||
|
|
||||||
import org.apache.spark.ml.linalg._
|
import org.apache.spark.ml.linalg._
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
import org.apache.spark.Partitioner
|
import org.apache.spark.Partitioner
|
||||||
|
|
||||||
class XGBoostClassifierSuite extends FunSuite with PerTest {
|
class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||||
@ -102,6 +104,8 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
assert(model.getEta == 0.1)
|
assert(model.getEta == 0.1)
|
||||||
assert(model.getMaxDepth == 6)
|
assert(model.getMaxDepth == 6)
|
||||||
assert(model.numClasses == 6)
|
assert(model.numClasses == 6)
|
||||||
|
val transformedDf = model.transform(trainingDF)
|
||||||
|
assert(!transformedDf.columns.contains("probability"))
|
||||||
}
|
}
|
||||||
|
|
||||||
test("use base margin") {
|
test("use base margin") {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user