[jvm-packages] Do not repartition when nWorker = 1 (#7676)
This commit is contained in:
parent
f08c5dcb06
commit
131858e7cb
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2021 by Contributors
|
Copyright (c) 2021-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -397,11 +397,22 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
// No light cost way to get number of partitions from DataFrame, so always repartition
|
// No light cost way to get number of partitions from DataFrame, so always repartition
|
||||||
val newDF = colData.groupColName
|
val newDF = colData.groupColName
|
||||||
.map(gn => repartitionForGroup(gn, colData.rawDF, nWorkers))
|
.map(gn => repartitionForGroup(gn, colData.rawDF, nWorkers))
|
||||||
.getOrElse(colData.rawDF.repartition(nWorkers))
|
.getOrElse(repartitionInputData(colData.rawDF, nWorkers))
|
||||||
name -> ColumnDataBatch(newDF, colData.colIndices, colData.groupColName)
|
name -> ColumnDataBatch(newDF, colData.colIndices, colData.groupColName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private def repartitionInputData(dataFrame: DataFrame, nWorkers: Int): DataFrame = {
|
||||||
|
// We can't check dataFrame.rdd.getNumPartitions == nWorkers here, since dataFrame.rdd is
|
||||||
|
// a lazy variable. If we call it here, we will not directly extract RDD[Table] again,
|
||||||
|
// instead, we will involve Columnar -> Row -> Columnar and decrease the performance
|
||||||
|
if (nWorkers == 1) {
|
||||||
|
dataFrame.coalesce(1)
|
||||||
|
} else {
|
||||||
|
dataFrame.repartition(nWorkers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private def repartitionForGroup(
|
private def repartitionForGroup(
|
||||||
groupName: String,
|
groupName: String,
|
||||||
dataFrame: DataFrame,
|
dataFrame: DataFrame,
|
||||||
@ -415,7 +426,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
|
|
||||||
implicit val encoder = RowEncoder(schema)
|
implicit val encoder = RowEncoder(schema)
|
||||||
// Expand the grouped rows after repartition
|
// Expand the grouped rows after repartition
|
||||||
groupedDF.repartition(nWorkers).mapPartitions(iter => {
|
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
|
||||||
new Iterator[Row] {
|
new Iterator[Row] {
|
||||||
var iterInRow: Iterator[Any] = Iterator.empty
|
var iterInRow: Iterator[Any] = Iterator.empty
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user