diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 34e59b0dd..65f0ef30f 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -70,30 +70,53 @@ private[spark] case class XGBLabeledPointGroup(
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
 
-  private[spark] def removeMissingValues(
-      xgbLabelPoints: Iterator[XGBLabeledPoint],
-      missing: Float): Iterator[XGBLabeledPoint] = {
-    if (!missing.isNaN) {
-      xgbLabelPoints.map { labeledPoint =>
-        val indicesBuilder = new mutable.ArrayBuilder.ofInt()
-        val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
-        for ((value, i) <- labeledPoint.values.zipWithIndex if value != missing) {
-          indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
-          valuesBuilder += value
+  private def verifyMissingSetting(xgbLabelPoints: Iterator[XGBLabeledPoint], missing: Float):
+      Iterator[XGBLabeledPoint] = {
+    if (missing != 0.0f) {
+      xgbLabelPoints.map(labeledPoint => {
+        if (labeledPoint.indices != null) {
+          throw new RuntimeException("you can only specify missing value as 0.0 when you have" +
+            " SparseVector as your feature format")
         }
-        labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
-      }
+        labeledPoint
+      })
     } else {
       xgbLabelPoints
     }
   }
 
-  private def removeMissingValuesWithGroup(
+  private def removeMissingValues(
+      xgbLabelPoints: Iterator[XGBLabeledPoint],
+      missing: Float,
+      keepCondition: Float => Boolean): Iterator[XGBLabeledPoint] = {
+    xgbLabelPoints.map { labeledPoint =>
+      val indicesBuilder = new mutable.ArrayBuilder.ofInt()
+      val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
+      for ((value, i) <- labeledPoint.values.zipWithIndex if keepCondition(value)) {
+        indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
+        valuesBuilder += value
+      }
+      labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
+    }
+  }
+
+  private[spark] def processMissingValues(
+      xgbLabelPoints: Iterator[XGBLabeledPoint],
+      missing: Float): Iterator[XGBLabeledPoint] = {
+    if (!missing.isNaN) {
+      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
+        missing, (v: Float) => v != missing)
+    } else {
+      removeMissingValues(xgbLabelPoints, missing, (v: Float) => !v.isNaN)
+    }
+  }
+
+  private def processMissingValuesWithGroup(
       xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
       missing: Float): Iterator[Array[XGBLabeledPoint]] = {
     if (!missing.isNaN) {
       xgbLabelPointGroups.map {
-        labeledPoints => XGBoost.removeMissingValues(labeledPoints.iterator, missing).toArray
+        labeledPoints => XGBoost.processMissingValues(labeledPoints.iterator, missing).toArray
       }
     } else {
       xgbLabelPointGroups
@@ -310,7 +333,7 @@ object XGBoost extends Serializable {
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPoints => {
         val watches = Watches.buildWatches(params,
-          removeMissingValues(labeledPoints, missing),
+          processMissingValues(labeledPoints, missing),
           getCacheDirName(useExternalMemory))
         buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
           obj, eval, prevBooster)
@@ -320,7 +343,7 @@ object XGBoost extends Serializable {
         nameAndLabeledPointSets =>
           val watches = Watches.buildWatches(
             nameAndLabeledPointSets.map {
-              case (name, iter) => (name, removeMissingValues(iter, missing))},
+              case (name, iter) => (name, processMissingValues(iter, missing))},
             getCacheDirName(useExternalMemory))
           buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
             obj, eval, prevBooster)
@@ -340,7 +363,7 @@ object XGBoost extends Serializable {
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPointGroups => {
         val watches = Watches.buildWatchesWithGroup(params,
-          removeMissingValuesWithGroup(labeledPointGroups, missing),
+          processMissingValuesWithGroup(labeledPointGroups, missing),
           getCacheDirName(useExternalMemory))
         buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
       }).cache()
@@ -349,7 +372,7 @@ object XGBoost extends Serializable {
       labeledPointGroupSets => {
         val watches = Watches.buildWatchesWithGroup(
           labeledPointGroupSets.map {
-            case (name, iter) => (name, removeMissingValuesWithGroup(iter, missing))
+            case (name, iter) => (name, processMissingValuesWithGroup(iter, missing))
           }, getCacheDirName(useExternalMemory))
         buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval,
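
The core of this change is the split of the old removeMissingValues into three pieces: verifyMissingSetting fails fast when a non-zero missing value is configured for SparseVector input (zero entries are silently dropped when a sparse vector is built, so no other sentinel can be honored), removeMissingValues filters feature entries against a pluggable keepCondition, and processMissingValues wires the two together, using v != missing for an explicit sentinel and !v.isNaN when missing is NaN. A minimal standalone sketch of these semantics, not the library code itself; Point is a hypothetical stand-in for the internal XGBLabeledPoint, with indices == null meaning dense:

object MissingValueSemantics {
  // Stand-in for XGBLabeledPoint: indices == null denotes a dense vector.
  final case class Point(indices: Array[Int], values: Array[Float])

  // Mirrors the refactored removeMissingValues: keep entries satisfying
  // keepCondition, remapping positional indices for dense input.
  def filter(p: Point, keepCondition: Float => Boolean): Point = {
    val kept = p.values.zipWithIndex.filter { case (v, _) => keepCondition(v) }
    Point(
      indices = kept.map { case (_, i) => if (p.indices == null) i else p.indices(i) },
      values = kept.map(_._1))
  }

  def main(args: Array[String]): Unit = {
    val dense = Point(indices = null, values = Array(1.0f, Float.NaN, 0.5f))
    // missing = NaN (the default): drop NaN entries.
    println(filter(dense, v => !v.isNaN).values.mkString(","))   // 1.0,0.5
    // missing = 0.5f on dense input: drop entries equal to the sentinel.
    println(filter(dense, v => v != 0.5f).values.mkString(","))  // 1.0,NaN
  }
}
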
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index 366f871c2..bda9189b7 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -256,7 +256,7 @@ class XGBoostClassificationModel private[ml](
    */
   override def predict(features: Vector): Double = {
     import DataUtils._
-    val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
+    val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
     val probability = _booster.predict(data = dm)(0).map(_.toDouble)
     if (numClasses == 2) {
       math.round(probability(0))
@@ -303,7 +303,7 @@ class XGBoostClassificationModel private[ml](
         }
       }
       val dm = new DMatrix(
-        XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
+        XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
         cacheInfo)
       try {
         val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
index 0abad8b9c..b47bca27b 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
@@ -247,7 +247,7 @@ class XGBoostRegressionModel private[ml] (
    */
   override def predict(features: Vector): Double = {
     import DataUtils._
-    val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
+    val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
     _booster.predict(data = dm)(0)(0)
   }
 
@@ -275,7 +275,7 @@ class XGBoostRegressionModel private[ml] (
         }
       }
       val dm = new DMatrix(
-        XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
+        XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
         cacheInfo)
       try {
         val Array(originalPredictionItr, predLeafItr, predContribItr) =
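
The model classes change mechanically: every prediction path now routes features through processMissingValues, so single-instance and batch scoring apply the same missing-value rules as training. A brief usage sketch of the single-instance path touched by the hunks above; the model path is illustrative, assuming a model previously saved with model.write.save(...):

import org.apache.spark.ml.linalg.Vectors
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel

// Hypothetical path to a previously saved classification model.
val model = XGBoostClassificationModel.load("/tmp/xgbClassificationModel")

// With the default missing = Float.NaN, the NaN entry below is dropped by
// processMissingValues before the single-row DMatrix is built.
val label = model.predict(Vectors.dense(1.0, Double.NaN, 0.5))
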
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 09b5a8883..50f827c34 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -33,6 +33,8 @@ import scala.util.Random
 
 import ml.dmlc.xgboost4j.java.Rabit
 
+import org.apache.spark.ml.feature.VectorAssembler
+
 class XGBoostGeneralSuite extends FunSuite with PerTest {
 
   test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
@@ -227,26 +229,45 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     def buildDenseDataFrame(): DataFrame = {
       val numRows = 100
       val numCols = 5
-
       val data = (0 until numRows).map { x =>
         val label = Random.nextInt(2)
         val values = Array.tabulate[Double](numCols) { c =>
-          if (c == numCols - 1) -0.1 else Random.nextDouble
+          if (c == numCols - 1) 0 else Random.nextDouble
         }
-
         (label, Vectors.dense(values))
       }
-
       ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
     }
-
     val denseDF = buildDenseDataFrame().repartition(4)
     val paramMap = List("eta" -> "1", "max_depth" -> "2",
-      "objective" -> "binary:logistic", "missing" -> -0.1f, "num_workers" -> numWorkers).toMap
+      "objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
     val model = new XGBoostClassifier(paramMap).fit(denseDF)
     model.transform(denseDF).collect()
   }
 
+  test("handle Float.NaN as missing value correctly") {
+    val spark = ss
+    import spark.implicits._
+    val testDF = Seq(
+      (1.0f, 0.0f, Float.NaN, 1.0),
+      (1.0f, 0.0f, 1.0f, 1.0),
+      (0.0f, 1.0f, 0.0f, 0.0),
+      (1.0f, 0.0f, 1.0f, 1.0),
+      (1.0f, Float.NaN, 0.0f, 0.0),
+      (0.0f, 0.0f, 0.0f, 0.0),
+      (0.0f, 1.0f, 0.0f, 1.0),
+      (Float.NaN, 0.0f, 0.0f, 1.0)
+    ).toDF("col1", "col2", "col3", "label")
+    val vectorAssembler = new VectorAssembler()
+      .setInputCols(Array("col1", "col2", "col3"))
+      .setOutputCol("features")
+    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
+    val paramMap = List("eta" -> "1", "max_depth" -> "2",
+      "objective" -> "binary:logistic", "num_workers" -> 1).toMap
+    val model = new XGBoostClassifier(paramMap).fit(inputDF)
+    model.transform(inputDF).collect()
+  }
+
   test("training with spark parallelism checks disabled") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
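
The new test exercises the Float.NaN default end to end. From the user's side, the resulting rules can be sketched as follows; a minimal sketch, assuming an active SparkSession named spark, with illustrative column names and a -999.0f sentinel that is not part of this patch:

import org.apache.spark.ml.feature.VectorAssembler
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import spark.implicits._

// All feature values are non-zero so VectorAssembler emits dense vectors.
val raw = Seq(
  (1.0f, -999.0f, 1.0),  // -999 marks the missing entry in col2
  (2.0f, 3.0f, 0.0),
  (4.0f, -999.0f, 0.0),
  (5.0f, 6.0f, 1.0)
).toDF("col1", "col2", "label")

val features = new VectorAssembler()
  .setInputCols(Array("col1", "col2"))
  .setOutputCol("features")
  .transform(raw)
  .select("features", "label")

// Dense features: any sentinel is allowed. With SparseVector features,
// training now fails fast unless missing is 0.0f, since zero entries are
// already dropped when the sparse vector is constructed.
val model = new XGBoostClassifier(Map(
  "eta" -> "1", "max_depth" -> "2",
  "objective" -> "binary:logistic",
  "missing" -> -999.0f,
  "num_workers" -> 1)).fit(features)

model.transform(features).collect()
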