Make sure 'thresholds' are considered when executing predict method (#3577)

This commit is contained in:
Matthew Tovbin 2018-08-13 14:04:47 -07:00 committed by Nan Zhu
parent 6288f6d563
commit ce0f0568a6

View File

@ -251,11 +251,11 @@ class XGBoostClassificationModel private[ml](
override def predict(features: Vector): Double = {
import DataUtils._
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
val probability = _booster.predict(data = dm)(0)
val probability = _booster.predict(data = dm)(0).map(_.toDouble)
if (numClasses == 2) {
math.round(probability(0))
} else {
Vectors.dense(probability.map(_.toDouble)).argmax
probability2prediction(Vectors.dense(probability))
}
}