[jvm-packages] Fix vector size of 'rawPredictionCol' in XGBoostClassificationModel (#3932)

* Fix vector size of 'rawPredictionCol' in XGBoostClassificationModel

* Fix UT
This commit is contained in:
Huafeng Wang
2018-11-24 13:09:43 +08:00
committed by Nan Zhu
parent f9302a56fb
commit 42cac4a30b
2 changed files with 6 additions and 3 deletions

View File

@@ -415,7 +415,9 @@ class XGBoostClassificationModel private[ml](
var numColsOutput = 0
val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
Vectors.dense(rawPrediction.map(_.toDouble).toArray)
val raw = rawPrediction.map(_.toDouble).toArray
val rawPredictions = if (numClasses == 2) Array(-raw(0), raw(0)) else raw
Vectors.dense(rawPredictions)
}
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>