[jvm-packages] consider spark.task.cpus when controlling parallelism (#3530)
* add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * consider spark.task.cpus when controlling parallelism * fix bug * fix conf setup * calculate requestedCores within ParallelismController * enforce spark.task.cpus = 1 * unify unit test case framework * enable spark ui
This commit is contained in:
@@ -19,7 +19,8 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import java.io.File
|
||||
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import org.apache.spark.SparkContext
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.{BeforeAndAfterEach, FunSuite}
|
||||
|
||||
@@ -37,6 +38,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
||||
.appName("XGBoostSuite")
|
||||
.config("spark.ui.enabled", false)
|
||||
.config("spark.driver.memory", "512m")
|
||||
.config("spark.task.cpus", 1)
|
||||
|
||||
override def beforeEach(): Unit = getOrCreateSession
|
||||
|
||||
|
||||
@@ -16,22 +16,24 @@
|
||||
|
||||
package org.apache.spark
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
import _root_.ml.dmlc.xgboost4j.scala.spark.PerTest
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
|
||||
class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
|
||||
var sc: SparkContext = _
|
||||
var numParallelism: Int = _
|
||||
class SparkParallelismTrackerSuite extends FunSuite with PerTest {
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
val conf: SparkConf = new SparkConf()
|
||||
.setMaster("local[*]")
|
||||
.setAppName("XGBoostSuite")
|
||||
sc = new SparkContext(conf)
|
||||
numParallelism = sc.defaultParallelism
|
||||
}
|
||||
val numParallelism: Int = Runtime.getRuntime.availableProcessors()
|
||||
|
||||
test("tracker should not affect execution result") {
|
||||
override protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
|
||||
.master("local[*]")
|
||||
.appName("XGBoostSuite")
|
||||
.config("spark.ui.enabled", true)
|
||||
.config("spark.driver.memory", "512m")
|
||||
.config("spark.task.cpus", 1)
|
||||
|
||||
test("tracker should not affect execution result when timeout is not larger than 0") {
|
||||
val nWorkers = numParallelism
|
||||
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
|
||||
val tracker = new SparkParallelismTracker(sc, 10000, nWorkers)
|
||||
@@ -54,4 +56,21 @@ class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("tracker should throw exception if parallelism is not sufficient with" +
|
||||
" spark.task.cpus larger than 1") {
|
||||
sc.conf.set("spark.task.cpus", "2")
|
||||
val nWorkers = numParallelism
|
||||
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
|
||||
val tracker = new SparkParallelismTracker(sc, 1000, nWorkers)
|
||||
intercept[IllegalStateException] {
|
||||
tracker.execute {
|
||||
rdd.map { i =>
|
||||
// Test interruption
|
||||
Thread.sleep(Long.MaxValue)
|
||||
i
|
||||
}.sum()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user