diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala index b16142ae0..341db97bc 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala @@ -19,15 +19,16 @@ package ml.dmlc.xgboost4j.scala.spark import java.io.File import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} - import org.apache.spark.{SparkConf, SparkContext, TaskFailedListener} import org.apache.spark.sql._ import org.scalatest.{BeforeAndAfterEach, FunSuite} + +import scala.math.min import scala.util.Random trait PerTest extends BeforeAndAfterEach { self: FunSuite => - protected val numWorkers: Int = Runtime.getRuntime.availableProcessors() + protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4) @transient private var currentSession: SparkSession = _ @@ -35,7 +36,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite => implicit def sc: SparkContext = ss.sparkContext protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder() - .master("local[*]") + .master(s"local[${numWorkers}]") .appName("XGBoostSuite") .config("spark.ui.enabled", false) .config("spark.driver.memory", "512m") diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala index ba3b15338..7f344674f 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala @@ -18,16 +18,17 @@ package org.apache.spark import org.scalatest.FunSuite import _root_.ml.dmlc.xgboost4j.scala.spark.PerTest - import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession +import scala.math.min + class SparkParallelismTrackerSuite extends FunSuite with PerTest { - val numParallelism: Int = Runtime.getRuntime.availableProcessors() + val numParallelism: Int = min(Runtime.getRuntime.availableProcessors(), 4) override protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder() - .master("local[*]") + .master(s"local[${numParallelism}]") .appName("XGBoostSuite") .config("spark.ui.enabled", true) .config("spark.driver.memory", "512m")