[jvm-packages] Fix vector size of 'rawPredictionCol' in XGBoostClassificationModel (#3932)
* Fix vector size of 'rawPredictionCol' in XGBoostClassificationModel * Fix UT
This commit is contained in:
parent
f9302a56fb
commit
42cac4a30b
@@ -415,7 +415,9 @@ class XGBoostClassificationModel private[ml](
     var numColsOutput = 0

     val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
-      Vectors.dense(rawPrediction.map(_.toDouble).toArray)
+      val raw = rawPrediction.map(_.toDouble).toArray
+      val rawPredictions = if (numClasses == 2) Array(-raw(0), raw(0)) else raw
+      Vectors.dense(rawPredictions)
     }

     val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
|||||||
@@ -60,10 +60,11 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
      collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap

    assert(testDF.count() === prediction4.size)
+    // the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
    for (i <- prediction3.indices) {
-      assert(prediction3(i).length === prediction4(i).values.length)
+      assert(prediction3(i).length === prediction4(i).values.length - 1)
      for (j <- prediction3(i).indices) {
-        assert(prediction3(i)(j) === prediction4(i)(j))
+        assert(prediction3(i)(j) === prediction4(i)(j + 1))
      }
    }

|||||||
Loading…
x
Reference in New Issue
Block a user