[jvm-package] remove the coalesce in barrier mode (#7846)
This commit is contained in:
parent
fdf533f2b9
commit
686caad40c
@ -69,7 +69,7 @@ public class BoosterTest {
|
||||
.hasHeader().build();
|
||||
|
||||
int maxBin = 16;
|
||||
int round = 100;
|
||||
int round = 10;
|
||||
//set params
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
|
||||
@ -407,15 +407,10 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
}
|
||||
|
||||
private def repartitionInputData(dataFrame: DataFrame, nWorkers: Int): DataFrame = {
|
||||
// We can't check dataFrame.rdd.getNumPartitions == nWorkers here, since dataFrame.rdd is
|
||||
// a lazy variable. If we call it here, we will not directly extract RDD[Table] again,
|
||||
// instead, we will involve Columnar -> Row -> Columnar and decrease the performance
|
||||
if (nWorkers == 1) {
|
||||
dataFrame.coalesce(1)
|
||||
} else {
|
||||
// we can't involve any coalesce operation here, since Barrier mode will check
|
||||
// the RDD patterns which does not allow coalesce.
|
||||
dataFrame.repartition(nWorkers)
|
||||
}
|
||||
}
|
||||
|
||||
private def repartitionForGroup(
|
||||
groupName: String,
|
||||
|
||||
@ -39,13 +39,8 @@ trait GpuTestSuite extends FunSuite with TmpFolderSuite {
|
||||
|
||||
def enableCsvConf(): SparkConf = {
|
||||
new SparkConf()
|
||||
.set(RapidsConf.ENABLE_READ_CSV_DATES.key, "true")
|
||||
.set(RapidsConf.ENABLE_READ_CSV_BYTES.key, "true")
|
||||
.set(RapidsConf.ENABLE_READ_CSV_SHORTS.key, "true")
|
||||
.set(RapidsConf.ENABLE_READ_CSV_INTEGERS.key, "true")
|
||||
.set(RapidsConf.ENABLE_READ_CSV_LONGS.key, "true")
|
||||
.set(RapidsConf.ENABLE_READ_CSV_FLOATS.key, "true")
|
||||
.set(RapidsConf.ENABLE_READ_CSV_DOUBLES.key, "true")
|
||||
.set("spark.rapids.sql.csv.read.float.enabled", "true")
|
||||
.set("spark.rapids.sql.csv.read.double.enabled", "true")
|
||||
}
|
||||
|
||||
def withGpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
|
||||
@ -246,12 +241,13 @@ object SparkSessionHolder extends Logging {
|
||||
Locale.setDefault(Locale.US)
|
||||
|
||||
val builder = SparkSession.builder()
|
||||
.master("local[1]")
|
||||
.master("local[2]")
|
||||
.config("spark.sql.adaptive.enabled", "false")
|
||||
.config("spark.rapids.sql.enabled", "false")
|
||||
.config("spark.rapids.sql.test.enabled", "false")
|
||||
.config("spark.plugins", "com.nvidia.spark.SQLPlugin")
|
||||
.config("spark.rapids.memory.gpu.pooling.enabled", "false") // Disable RMM for unit tests.
|
||||
.config("spark.sql.files.maxPartitionBytes", "1000")
|
||||
.appName("XGBoost4j-Spark-Gpu unit test")
|
||||
|
||||
builder.getOrCreate()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user