diff --git a/dmlc-core b/dmlc-core index 4e6459b0b..1db0792e1 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 4e6459b0bc15e6cf9b315cc75e2e5495c03cd417 +Subproject commit 1db0792e1a55355b1f07699bba18c88ded996953 diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 5d1a05c17..a68526b17 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -70,6 +70,13 @@ object XGBoost extends Serializable { obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = { val numWorkers = trainingData.partitions.length implicit val sc = trainingData.sparkContext + if (configMap.contains("nthread")) { + val nThread = configMap("nthread") + val coresPerTask = sc.getConf.get("spark.task.cpus", "1") + require(nThread.toString.toInt <= coresPerTask.toInt, + s"the nthread configuration ($nThread) must be no larger than " + + s"spark.task.cpus ($coresPerTask)") + } val tracker = new RabitTracker(numWorkers) require(tracker.start(), "FAULT: Failed to start tracker") val boosters = buildDistributedBoosters(trainingData, configMap, diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala index 98964f23e..21ef3f7a2 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala @@ -27,12 +27,12 @@ import org.apache.spark.mllib.linalg.DenseVector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} -import 
org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, FunSuite} import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError} import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} -class XGBoostSuite extends FunSuite with BeforeAndAfterAll { +class XGBoostSuite extends FunSuite with BeforeAndAfter { private implicit var sc: SparkContext = null private val numWorker = 2 @@ -79,13 +79,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll { } } - override def beforeAll(): Unit = { + before { // build SparkContext val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite") sc = new SparkContext(sparkConf) } - override def afterAll(): Unit = { + after { if (sc != null) { sc.stop() } @@ -112,9 +112,9 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll { sampleList.toList } - private def buildTrainingRDD(): RDD[LabeledPoint] = { + private def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = { val sampleList = readFile(getClass.getResource("/agaricus.txt.train").getFile) - sc.parallelize(sampleList, numWorker) + sparkContext.getOrElse(sc).parallelize(sampleList, numWorker) } test("build RDD containing boosters") { @@ -155,4 +155,21 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll { val predicts = loadedXGBooostModel.predict(testSetDMatrix) assert(eval.eval(predicts, testSetDMatrix) < 0.1) } + + test("nthread configuration must be no larger than spark.task.cpus") { + // close the current Spark context + sc.stop() + sc = null + // start another app + val sparkConf = new SparkConf().setMaster("local[*]").set("spark.task.cpus", "4"). 
+ setAppName("test1") + val customSparkContext = new SparkContext(sparkConf) + val trainingRDD = buildTrainingRDD(Some(customSparkContext)) + val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", + "objective" -> "binary:logistic", "nthread" -> 6).toMap + intercept[IllegalArgumentException] { + XGBoost.train(trainingRDD, paramMap, 5) + } + customSparkContext.stop() + } }