Clean the way deterministic paritioning is computed (#6033)

We propose to only use the rowHashCode to compute the partitionKey, adding the FeatureValue hashCode does not bring more value and would make the computation slower. Even though a collision would appear at 0.2% with MurmurHash3 this is bearable for partitioning, this won't have any impact on the data balancing.
This commit is contained in:
Anthony D'Amato 2020-08-30 23:38:23 +02:00 committed by GitHub
parent c1ca872d1e
commit ada964f16e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -70,50 +70,13 @@ object DataUtils extends Serializable {
} }
} }
private def featureValueOfDenseVector(rowHashCode: Int, features: DenseVector): Float = {
val featureId = {
if (rowHashCode > 0) {
rowHashCode % features.size
} else {
// prevent overflow
math.abs(rowHashCode + 1) % features.size
}
}
features.values(featureId).toFloat
}
private def featureValueOfSparseVector(rowHashCode: Int, features: SparseVector): Float = {
val featureId = {
if (rowHashCode > 0) {
rowHashCode % features.indices.length
} else {
// prevent overflow
math.abs(rowHashCode + 1) % features.indices.length
}
}
features.values(featureId).toFloat
}
private def calculatePartitionKey(row: Row, numPartitions: Int): Int = {
val Row(_, features: Vector, _, _) = row
val rowHashCode = row.hashCode()
val featureValue = features match {
case denseVector: DenseVector =>
featureValueOfDenseVector(rowHashCode, denseVector)
case sparseVector: SparseVector =>
featureValueOfSparseVector(rowHashCode, sparseVector)
}
val nonNaNFeatureValue = if (featureValue.isNaN) { 0.0f } else { featureValue }
math.abs((rowHashCode.toLong + nonNaNFeatureValue).toString.hashCode % numPartitions)
}
private def attachPartitionKey( private def attachPartitionKey(
row: Row, row: Row,
deterministicPartition: Boolean, deterministicPartition: Boolean,
numWorkers: Int, numWorkers: Int,
xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = { xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
if (deterministicPartition) { if (deterministicPartition) {
(calculatePartitionKey(row, numWorkers), xgbLp) (math.abs(row.hashCode() % numWorkers), xgbLp)
} else { } else {
(1, xgbLp) (1, xgbLp)
} }