Clean the way deterministic paritioning is computed (#6033)
We propose to only use the rowHashCode to compute the partitionKey, adding the FeatureValue hashCode does not bring more value and would make the computation slower. Even though a collision would appear at 0.2% with MurmurHash3 this is bearable for partitioning, this won't have any impact on the data balancing.
This commit is contained in:
parent
c1ca872d1e
commit
ada964f16e
@ -70,50 +70,13 @@ object DataUtils extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
private def featureValueOfDenseVector(rowHashCode: Int, features: DenseVector): Float = {
|
||||
val featureId = {
|
||||
if (rowHashCode > 0) {
|
||||
rowHashCode % features.size
|
||||
} else {
|
||||
// prevent overflow
|
||||
math.abs(rowHashCode + 1) % features.size
|
||||
}
|
||||
}
|
||||
features.values(featureId).toFloat
|
||||
}
|
||||
|
||||
private def featureValueOfSparseVector(rowHashCode: Int, features: SparseVector): Float = {
|
||||
val featureId = {
|
||||
if (rowHashCode > 0) {
|
||||
rowHashCode % features.indices.length
|
||||
} else {
|
||||
// prevent overflow
|
||||
math.abs(rowHashCode + 1) % features.indices.length
|
||||
}
|
||||
}
|
||||
features.values(featureId).toFloat
|
||||
}
|
||||
|
||||
private def calculatePartitionKey(row: Row, numPartitions: Int): Int = {
|
||||
val Row(_, features: Vector, _, _) = row
|
||||
val rowHashCode = row.hashCode()
|
||||
val featureValue = features match {
|
||||
case denseVector: DenseVector =>
|
||||
featureValueOfDenseVector(rowHashCode, denseVector)
|
||||
case sparseVector: SparseVector =>
|
||||
featureValueOfSparseVector(rowHashCode, sparseVector)
|
||||
}
|
||||
val nonNaNFeatureValue = if (featureValue.isNaN) { 0.0f } else { featureValue }
|
||||
math.abs((rowHashCode.toLong + nonNaNFeatureValue).toString.hashCode % numPartitions)
|
||||
}
|
||||
|
||||
private def attachPartitionKey(
|
||||
row: Row,
|
||||
deterministicPartition: Boolean,
|
||||
numWorkers: Int,
|
||||
xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
|
||||
if (deterministicPartition) {
|
||||
(calculatePartitionKey(row, numWorkers), xgbLp)
|
||||
(math.abs(row.hashCode() % numWorkers), xgbLp)
|
||||
} else {
|
||||
(1, xgbLp)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user