Make sure 'thresholds' are considered when executing predict method (#3577)
This commit is contained in:
parent
6288f6d563
commit
ce0f0568a6
@ -251,11 +251,11 @@ class XGBoostClassificationModel private[ml](
|
|||||||
override def predict(features: Vector): Double = {
|
override def predict(features: Vector): Double = {
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
|
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
|
||||||
val probability = _booster.predict(data = dm)(0)
|
val probability = _booster.predict(data = dm)(0).map(_.toDouble)
|
||||||
if (numClasses == 2) {
|
if (numClasses == 2) {
|
||||||
math.round(probability(0))
|
math.round(probability(0))
|
||||||
} else {
|
} else {
|
||||||
Vectors.dense(probability.map(_.toDouble)).argmax
|
probability2prediction(Vectors.dense(probability))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user