[jvm-packages] Do not repartition when nWorker = 1 (#7676)

This commit is contained in:
Bobby Wang 2022-02-19 21:45:54 +08:00 committed by GitHub
parent f08c5dcb06
commit 131858e7cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2021 by Contributors
Copyright (c) 2021-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -397,11 +397,22 @@ object GpuPreXGBoost extends PreXGBoostProvider {
// No light cost way to get number of partitions from DataFrame, so always repartition
val newDF = colData.groupColName
.map(gn => repartitionForGroup(gn, colData.rawDF, nWorkers))
.getOrElse(colData.rawDF.repartition(nWorkers))
.getOrElse(repartitionInputData(colData.rawDF, nWorkers))
name -> ColumnDataBatch(newDF, colData.colIndices, colData.groupColName)
}
}
private def repartitionInputData(dataFrame: DataFrame, nWorkers: Int): DataFrame = {
// We can't check dataFrame.rdd.getNumPartitions == nWorkers here, since dataFrame.rdd is
// a lazy variable. If we call it here, we will not directly extract RDD[Table] again,
// instead, we will involve Columnar -> Row -> Columnar and decrease the performance
if (nWorkers == 1) {
dataFrame.coalesce(1)
} else {
dataFrame.repartition(nWorkers)
}
}
private def repartitionForGroup(
groupName: String,
dataFrame: DataFrame,
@ -415,7 +426,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
implicit val encoder = RowEncoder(schema)
// Expand the grouped rows after repartition
groupedDF.repartition(nWorkers).mapPartitions(iter => {
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
new Iterator[Row] {
var iterInRow: Iterator[Any] = Iterator.empty