fix spark tests on machines with many cores (#4634)

Authored by Rong Ou on 2019-07-07 16:02:56 -07:00; committed by Nan Zhu
parent d333918f5e
commit 30204b50fe
2 changed files with 8 additions and 6 deletions

PerTest.scala

@@ -19,15 +19,16 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File

 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.spark.{SparkConf, SparkContext, TaskFailedListener}
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}

+import scala.math.min
 import scala.util.Random

 trait PerTest extends BeforeAndAfterEach { self: FunSuite =>

-  protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
+  protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)

   @transient private var currentSession: SparkSession = _

@@ -35,7 +36,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
   implicit def sc: SparkContext = ss.sparkContext

   protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
-    .master("local[*]")
+    .master(s"local[${numWorkers}]")
     .appName("XGBoostSuite")
     .config("spark.ui.enabled", false)
     .config("spark.driver.memory", "512m")

SparkParallelismTrackerSuite.scala

@@ -18,16 +18,17 @@ package org.apache.spark
 import org.scalatest.FunSuite

 import _root_.ml.dmlc.xgboost4j.scala.spark.PerTest
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
+import scala.math.min

 class SparkParallelismTrackerSuite extends FunSuite with PerTest {

-  val numParallelism: Int = Runtime.getRuntime.availableProcessors()
+  val numParallelism: Int = min(Runtime.getRuntime.availableProcessors(), 4)

   override protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
-    .master("local[*]")
+    .master(s"local[${numParallelism}]")
     .appName("XGBoostSuite")
     .config("spark.ui.enabled", true)
     .config("spark.driver.memory", "512m")