From 2b7a1c5780e85835679ea0fc5203b3edd67fc081 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Mon, 13 Aug 2018 14:05:07 -0700 Subject: [PATCH] [jvm-packages] Avoid losing precision when computing probabilities by converting to Double early (#3576) --- .../xgboost4j/scala/spark/XGBoostClassifier.scala | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index 9773e3447..47b489c22 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -411,20 +411,15 @@ class XGBoostClassificationModel private[ml]( } val probabilityUDF = udf { probability: mutable.WrappedArray[Float] => - if (numClasses == 2) { - Vectors.dense(Array(1 - probability(0), probability(0)).map(_.toDouble)) - } else { - Vectors.dense(probability.map(_.toDouble).toArray) - } + val prob = probability.map(_.toDouble).toArray + val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob + Vectors.dense(probabilities) } val predictUDF = udf { probability: mutable.WrappedArray[Float] => // From XGBoost probability to MLlib prediction - val probabilities = if (numClasses == 2) { - Array(1 - probability(0), probability(0)).map(_.toDouble) - } else { - probability.map(_.toDouble).toArray - } + val prob = probability.map(_.toDouble).toArray + val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob probability2prediction(Vectors.dense(probabilities)) }