diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala index df787d8eb..15ffe4c06 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala @@ -103,7 +103,8 @@ object DataUtils extends Serializable { case sparseVector: SparseVector => featureValueOfSparseVector(rowHashCode, sparseVector) } - math.abs((rowHashCode.toLong + featureValue).toString.hashCode % numPartitions) + val nonNaNFeatureValue = if (featureValue.isNaN) { 0.0f } else { featureValue } + math.abs((rowHashCode.toLong + nonNaNFeatureValue).toString.hashCode % numPartitions) } private def attachPartitionKey( diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala index 986b0843b..ff0492f41 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala @@ -16,6 +16,7 @@ package ml.dmlc.xgboost4j.scala.spark +import org.apache.spark.ml.linalg.Vectors import org.scalatest.FunSuite import org.apache.spark.sql.functions._ @@ -79,4 +80,34 @@ class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite wit map2 } } + + test("deterministic partitioning has a uniform repartition on dataset with missing values") { + val N = 10000 + val dataset = (0 until N).map{ n => + (n, n % 2, Vectors.sparse(3, Array(0, 1, 2), Array(Double.NaN, n, Double.NaN))) + } + + val df = ss.createDataFrame(sc.parallelize(dataset)).toDF("id", "label", "features") + + val dfRepartitioned = DataUtils.convertDataFrameToXGBLabeledPointRDDs( + col("label"), + col("features"), + lit(1.0), + lit(Float.NaN), + None, + 10, + deterministicPartition = true, + df + ).head + + val partitionsSizes = dfRepartitioned + .mapPartitions(iter => Array(iter.size.toDouble).iterator, true) + .collect() + val partitionMean = partitionsSizes.sum / partitionsSizes.length + val squaredDiffSum = partitionsSizes + .map(partitionSize => Math.pow(partitionSize - partitionMean, 2)) + val standardDeviation = math.sqrt(squaredDiffSum.sum / squaredDiffSum.length) + + assert(standardDeviation < math.sqrt(N.toDouble)) + } }