nthread no larger than spark.task.cpus

This commit is contained in:
CodingCat
2016-03-10 05:51:07 -05:00
parent bbe2b2f0b6
commit e0a3f1c000
3 changed files with 31 additions and 7 deletions

View File

@@ -70,6 +70,13 @@ object XGBoost extends Serializable {
obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
val numWorkers = trainingData.partitions.length
implicit val sc = trainingData.sparkContext
if (configMap.contains("nthread")) {
val nThread = configMap("nthread")
val coresPerTask = sc.getConf.get("spark.task.cpus", "1")
require(nThread.toString <= coresPerTask,
s"the nthread configuration ($nThread) must be no larger than " +
s"spark.task.cpus ($coresPerTask)")
}
val tracker = new RabitTracker(numWorkers)
require(tracker.start(), "FAULT: Failed to start tracker")
val boosters = buildDistributedBoosters(trainingData, configMap,