From b546321c834302f72a96291e042e46556d86c6ae Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Mon, 30 Jul 2018 14:16:24 -0700 Subject: [PATCH] [jvm-packages] the current version of xgboost does not consider missing value in prediction (#3529) * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * consider missing value in prediction * handle single prediction instance * fix type conversion --- .../scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala | 8 ++++---- .../dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala | 10 ++++++---- .../dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala | 10 ++++++---- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index d2febf61a..a419b12c5 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -55,11 +55,11 @@ object TrackerConf { object XGBoost extends Serializable { private val logger = LogFactory.getLog("XGBoostSpark") - private def removeMissingValues( - denseLabeledPoints: Iterator[XGBLabeledPoint], + private[spark] def removeMissingValues( + xgbLabelPoints: Iterator[XGBLabeledPoint], missing: Float): Iterator[XGBLabeledPoint] = { if (!missing.isNaN) { - denseLabeledPoints.map { labeledPoint => + xgbLabelPoints.map { labeledPoint => val indicesBuilder = new mutable.ArrayBuilder.ofInt() val valuesBuilder = new mutable.ArrayBuilder.ofFloat() for ((value, i) <- labeledPoint.values.zipWithIndex if value != missing) { @@ -69,7 +69,7 @@ object XGBoost extends Serializable { labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result()) } } else { - denseLabeledPoints + xgbLabelPoints } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index 85f9b3eab..5f81393e0 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -241,7 +241,7 @@ class XGBoostClassificationModel private[ml]( */ override def predict(features: Vector): Double = { import DataUtils._ - val dm = new DMatrix(Iterator(features.asXGB)) + val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing))) val probability = _booster.predict(data = dm)(0) if (numClasses == 2) { math.round(probability(0)) @@ -272,12 +272,12 @@ class XGBoostClassificationModel private[ml]( val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster) val appName = dataset.sparkSession.sparkContext.appName - val rdd = dataset.rdd.mapPartitions { rowIterator => + val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator => if (rowIterator.hasNext) { val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap Rabit.init(rabitEnv.asJava) val (rowItr1, rowItr2) = rowIterator.duplicate - val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector]( + val featuresIterator = rowItr2.map(row => row.getAs[Vector]( $(featuresCol))).toList.iterator import DataUtils._ val cacheInfo = { @@ -288,7 +288,9 @@ class XGBoostClassificationModel private[ml]( } } - val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo) + val dm = new DMatrix( + XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)), + cacheInfo) try { val rawPredictionItr = { bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala index 6b6c635bd..2d6568300 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -237,7 +237,7 @@ class XGBoostRegressionModel private[ml] ( */ override def predict(features: Vector): Double = { import DataUtils._ - val dm = new DMatrix(Iterator(features.asXGB)) + val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing))) _booster.predict(data = dm)(0)(0) } @@ -250,12 +250,12 @@ class XGBoostRegressionModel private[ml] ( val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster) val appName = dataset.sparkSession.sparkContext.appName - val rdd = dataset.rdd.mapPartitions { rowIterator => + val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator => if (rowIterator.hasNext) { val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap Rabit.init(rabitEnv.asJava) val (rowItr1, rowItr2) = rowIterator.duplicate - val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector]( + val featuresIterator = rowItr2.map(row => row.getAs[Vector]( $(featuresCol))).toList.iterator import DataUtils._ val cacheInfo = { @@ -266,7 +266,9 @@ class XGBoostRegressionModel private[ml] ( } } - val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo) + val dm = new DMatrix( + XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)), + cacheInfo) try { val originalPredictionItr = { bBooster.value.predict(dm).map(Row(_)).iterator