Make sure 'thresholds' are considered when executing predict method (#3577)
This commit is contained in:
parent
6288f6d563
commit
ce0f0568a6
@ -251,11 +251,11 @@ class XGBoostClassificationModel private[ml](
|
||||
override def predict(features: Vector): Double = {
|
||||
import DataUtils._
|
||||
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
|
||||
val probability = _booster.predict(data = dm)(0)
|
||||
val probability = _booster.predict(data = dm)(0).map(_.toDouble)
|
||||
if (numClasses == 2) {
|
||||
math.round(probability(0))
|
||||
} else {
|
||||
Vectors.dense(probability.map(_.toDouble)).argmax
|
||||
probability2prediction(Vectors.dense(probability))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user