[jvm-packages] Fix vector size of 'rawPredictionCol' in XGBoostClassificationModel (#3932)
* Fix vector size of 'rawPredictionCol' in XGBoostClassificationModel * Fix UT
This commit is contained in:
@@ -415,7 +415,9 @@ class XGBoostClassificationModel private[ml](
|
||||
var numColsOutput = 0
|
||||
|
||||
val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
|
||||
Vectors.dense(rawPrediction.map(_.toDouble).toArray)
|
||||
val raw = rawPrediction.map(_.toDouble).toArray
|
||||
val rawPredictions = if (numClasses == 2) Array(-raw(0), raw(0)) else raw
|
||||
Vectors.dense(rawPredictions)
|
||||
}
|
||||
|
||||
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||
|
||||
Reference in New Issue
Block a user