[jvm-packages] Do not repartition when nWorker = 1 (#7676)
This commit is contained in:
parent
f08c5dcb06
commit
131858e7cb
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -397,11 +397,22 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
// No light cost way to get number of partitions from DataFrame, so always repartition
|
||||
val newDF = colData.groupColName
|
||||
.map(gn => repartitionForGroup(gn, colData.rawDF, nWorkers))
|
||||
.getOrElse(colData.rawDF.repartition(nWorkers))
|
||||
.getOrElse(repartitionInputData(colData.rawDF, nWorkers))
|
||||
name -> ColumnDataBatch(newDF, colData.colIndices, colData.groupColName)
|
||||
}
|
||||
}
|
||||
|
||||
private def repartitionInputData(dataFrame: DataFrame, nWorkers: Int): DataFrame = {
|
||||
// We can't check dataFrame.rdd.getNumPartitions == nWorkers here, since dataFrame.rdd is
|
||||
// a lazy variable. If we call it here, we will not directly extract RDD[Table] again,
|
||||
// instead, we will involve Columnar -> Row -> Columnar and decrease the performance
|
||||
if (nWorkers == 1) {
|
||||
dataFrame.coalesce(1)
|
||||
} else {
|
||||
dataFrame.repartition(nWorkers)
|
||||
}
|
||||
}
|
||||
|
||||
private def repartitionForGroup(
|
||||
groupName: String,
|
||||
dataFrame: DataFrame,
|
||||
@ -415,7 +426,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
|
||||
implicit val encoder = RowEncoder(schema)
|
||||
// Expand the grouped rows after repartition
|
||||
groupedDF.repartition(nWorkers).mapPartitions(iter => {
|
||||
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
|
||||
new Iterator[Row] {
|
||||
var iterInRow: Iterator[Any] = Iterator.empty
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user