Make sure 'thresholds' are considered when executing predict method (#3577)

This commit is contained in:
Matthew Tovbin 2018-08-13 14:04:47 -07:00 committed by Nan Zhu
parent 6288f6d563
commit ce0f0568a6

View File

@ -251,11 +251,11 @@ class XGBoostClassificationModel private[ml](
override def predict(features: Vector): Double = { override def predict(features: Vector): Double = {
import DataUtils._ import DataUtils._
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing))) val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
val probability = _booster.predict(data = dm)(0) val probability = _booster.predict(data = dm)(0).map(_.toDouble)
if (numClasses == 2) { if (numClasses == 2) {
math.round(probability(0)) math.round(probability(0))
} else { } else {
Vectors.dense(probability.map(_.toDouble)).argmax probability2prediction(Vectors.dense(probability))
} }
} }