[jvm-packages] consider spark.task.cpus when controlling parallelism (#3530)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* consider spark.task.cpus when controlling parallelism

* fix bug

* fix conf setup

* calculate requestedCores within ParallelismController

* enforce spark.task.cpus = 1

* unify unit test case framework

* enable spark ui
This commit is contained in:
Nan Zhu
2018-07-31 06:19:45 -07:00
committed by GitHub
parent 860263f814
commit 6cf97b4eae
4 changed files with 42 additions and 19 deletions

View File

@@ -19,7 +19,8 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.SparkContext
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}
@@ -37,6 +38,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
.appName("XGBoostSuite")
.config("spark.ui.enabled", false)
.config("spark.driver.memory", "512m")
.config("spark.task.cpus", 1)
override def beforeEach(): Unit = getOrCreateSession

View File

@@ -16,22 +16,24 @@
package org.apache.spark
import org.scalatest.FunSuite
import _root_.ml.dmlc.xgboost4j.scala.spark.PerTest
import org.apache.spark.rdd.RDD
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.spark.sql.SparkSession
class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
var sc: SparkContext = _
var numParallelism: Int = _
class SparkParallelismTrackerSuite extends FunSuite with PerTest {
// Overrides beforeAll to build a SparkConf with master "local[*]" and app name
// "XGBoostSuite", create the SparkContext stored in `sc`, and record that
// context's defaultParallelism in `numParallelism`.
override def beforeAll(): Unit = {
val conf: SparkConf = new SparkConf()
.setMaster("local[*]")
.setAppName("XGBoostSuite")
sc = new SparkContext(conf)
// Tests below use this value as the number of workers to request.
numParallelism = sc.defaultParallelism
}
val numParallelism: Int = Runtime.getRuntime.availableProcessors()
test("tracker should not affect execution result") {
// Session builder used by this suite: master "local[*]", app name "XGBoostSuite",
// 512m of driver memory and spark.task.cpus = 1. Unlike the builder shown in the
// PerTest hunk, spark.ui.enabled is set to true here.
// NOTE(review): presumably PerTest creates its session from this builder — confirm.
override protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
.master("local[*]")
.appName("XGBoostSuite")
.config("spark.ui.enabled", true)
.config("spark.driver.memory", "512m")
.config("spark.task.cpus", 1)
test("tracker should not affect execution result when timeout is not larger than 0") {
val nWorkers = numParallelism
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
val tracker = new SparkParallelismTracker(sc, 10000, nWorkers)
@@ -54,4 +56,21 @@ class SparkParallelismTrackerSuite extends FunSuite with BeforeAndAfterAll {
}
}
}
// Expects tracker.execute to fail with IllegalStateException once spark.task.cpus
// has been raised to 2 while still requesting `numParallelism` workers.
// NOTE(review): presumably each task then occupies two cores, so the context cannot
// run `nWorkers` tasks at once and the tracker gives up — confirm against
// SparkParallelismTracker.
test("tracker should throw exception if parallelism is not sufficient with" +
" spark.task.cpus larger than 1") {
// Changes the setting on the conf object returned by sc.conf before the job runs.
sc.conf.set("spark.task.cpus", "2")
val nWorkers = numParallelism
// One element per requested worker.
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
// Second argument is 1000 here versus 10000 in the earlier test.
// NOTE(review): looks like a timeout in milliseconds — confirm.
val tracker = new SparkParallelismTracker(sc, 1000, nWorkers)
intercept[IllegalStateException] {
tracker.execute {
rdd.map { i =>
// Test interruption: the sleep is effectively unbounded, so the task is
// not expected to finish on its own.
Thread.sleep(Long.MaxValue)
i
}.sum()
}
}
}
}