fix spark tests on machines with many cores (#4634)
This commit is contained in:
parent
d333918f5e
commit
30204b50fe
@ -19,15 +19,16 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
import java.io.File
|
import java.io.File
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
|
|
||||||
import org.apache.spark.{SparkConf, SparkContext, TaskFailedListener}
|
import org.apache.spark.{SparkConf, SparkContext, TaskFailedListener}
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.scalatest.{BeforeAndAfterEach, FunSuite}
|
import org.scalatest.{BeforeAndAfterEach, FunSuite}
|
||||||
|
|
||||||
|
import scala.math.min
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
||||||
|
|
||||||
protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
|
protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
|
||||||
|
|
||||||
@transient private var currentSession: SparkSession = _
|
@transient private var currentSession: SparkSession = _
|
||||||
|
|
||||||
@ -35,7 +36,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
|||||||
implicit def sc: SparkContext = ss.sparkContext
|
implicit def sc: SparkContext = ss.sparkContext
|
||||||
|
|
||||||
protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
|
protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
|
||||||
.master("local[*]")
|
.master(s"local[${numWorkers}]")
|
||||||
.appName("XGBoostSuite")
|
.appName("XGBoostSuite")
|
||||||
.config("spark.ui.enabled", false)
|
.config("spark.ui.enabled", false)
|
||||||
.config("spark.driver.memory", "512m")
|
.config("spark.driver.memory", "512m")
|
||||||
|
|||||||
@ -18,16 +18,17 @@ package org.apache.spark
|
|||||||
|
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
import _root_.ml.dmlc.xgboost4j.scala.spark.PerTest
|
import _root_.ml.dmlc.xgboost4j.scala.spark.PerTest
|
||||||
|
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.SparkSession
|
import org.apache.spark.sql.SparkSession
|
||||||
|
|
||||||
|
import scala.math.min
|
||||||
|
|
||||||
class SparkParallelismTrackerSuite extends FunSuite with PerTest {
|
class SparkParallelismTrackerSuite extends FunSuite with PerTest {
|
||||||
|
|
||||||
val numParallelism: Int = Runtime.getRuntime.availableProcessors()
|
val numParallelism: Int = min(Runtime.getRuntime.availableProcessors(), 4)
|
||||||
|
|
||||||
override protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
|
override protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
|
||||||
.master("local[*]")
|
.master(s"local[${numParallelism}]")
|
||||||
.appName("XGBoostSuite")
|
.appName("XGBoostSuite")
|
||||||
.config("spark.ui.enabled", true)
|
.config("spark.ui.enabled", true)
|
||||||
.config("spark.driver.memory", "512m")
|
.config("spark.driver.memory", "512m")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user