nthread no larger than spark.task.cpus
This commit is contained in:
parent
bbe2b2f0b6
commit
e0a3f1c000
@ -1 +1 @@
|
||||
Subproject commit 4e6459b0bc15e6cf9b315cc75e2e5495c03cd417
|
||||
Subproject commit 1db0792e1a55355b1f07699bba18c88ded996953
|
||||
@ -70,6 +70,13 @@ object XGBoost extends Serializable {
|
||||
obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
|
||||
val numWorkers = trainingData.partitions.length
|
||||
implicit val sc = trainingData.sparkContext
|
||||
if (configMap.contains("nthread")) {
|
||||
val nThread = configMap("nthread")
|
||||
val coresPerTask = sc.getConf.get("spark.task.cpus", "1")
|
||||
require(nThread.toString <= coresPerTask,
|
||||
s"the nthread configuration ($nThread) must be no larger than " +
|
||||
s"spark.task.cpus ($coresPerTask)")
|
||||
}
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
val boosters = buildDistributedBoosters(trainingData, configMap,
|
||||
|
||||
@ -27,12 +27,12 @@ import org.apache.spark.mllib.linalg.DenseVector
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
import org.scalatest.{BeforeAndAfter, FunSuite}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
|
||||
|
||||
class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
|
||||
class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
||||
|
||||
private implicit var sc: SparkContext = null
|
||||
private val numWorker = 2
|
||||
@ -79,13 +79,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
|
||||
}
|
||||
}
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
before {
|
||||
// build SparkContext
|
||||
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
|
||||
sc = new SparkContext(sparkConf)
|
||||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
after {
|
||||
if (sc != null) {
|
||||
sc.stop()
|
||||
}
|
||||
@ -112,9 +112,9 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
|
||||
sampleList.toList
|
||||
}
|
||||
|
||||
private def buildTrainingRDD(): RDD[LabeledPoint] = {
|
||||
private def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
|
||||
val sampleList = readFile(getClass.getResource("/agaricus.txt.train").getFile)
|
||||
sc.parallelize(sampleList, numWorker)
|
||||
sparkContext.getOrElse(sc).parallelize(sampleList, numWorker)
|
||||
}
|
||||
|
||||
test("build RDD containing boosters") {
|
||||
@ -155,4 +155,21 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
|
||||
val predicts = loadedXGBooostModel.predict(testSetDMatrix)
|
||||
assert(eval.eval(predicts, testSetDMatrix) < 0.1)
|
||||
}
|
||||
|
||||
test("nthread configuration must be equal to spark.task.cpus") {
|
||||
// close the current Spark context
|
||||
sc.stop()
|
||||
sc = null
|
||||
// start another app
|
||||
val sparkConf = new SparkConf().setMaster("local[*]").set("spark.task.cpus", "4").
|
||||
setAppName("test1")
|
||||
val customSparkContext = new SparkContext(sparkConf)
|
||||
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "nthread" -> 6).toMap
|
||||
intercept[IllegalArgumentException] {
|
||||
XGBoost.train(trainingRDD, paramMap, 5)
|
||||
}
|
||||
customSparkContext.stop()
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user