diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala index b6399d58c..0c3521069 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2021 by Contributors + Copyright (c) 2021-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -397,11 +397,22 @@ object GpuPreXGBoost extends PreXGBoostProvider { // No light cost way to get number of partitions from DataFrame, so always repartition val newDF = colData.groupColName .map(gn => repartitionForGroup(gn, colData.rawDF, nWorkers)) - .getOrElse(colData.rawDF.repartition(nWorkers)) + .getOrElse(repartitionInputData(colData.rawDF, nWorkers)) name -> ColumnDataBatch(newDF, colData.colIndices, colData.groupColName) } } + private def repartitionInputData(dataFrame: DataFrame, nWorkers: Int): DataFrame = { + // We can't check dataFrame.rdd.getNumPartitions == nWorkers here, since dataFrame.rdd is + // a lazy variable. 
If we call it here, we will not directly extract RDD[Table] again; + instead, we will go through Columnar -> Row -> Columnar conversions, which degrades performance + if (nWorkers == 1) { + dataFrame.coalesce(1) + } else { + dataFrame.repartition(nWorkers) + } + } + + private def repartitionForGroup( groupName: String, dataFrame: DataFrame, @@ -415,7 +426,7 @@ object GpuPreXGBoost extends PreXGBoostProvider { implicit val encoder = RowEncoder(schema) // Expand the grouped rows after repartition - groupedDF.repartition(nWorkers).mapPartitions(iter => { + repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => { new Iterator[Row] { var iterInRow: Iterator[Any] = Iterator.empty