set nthread to spark.task.cpus by default

This commit is contained in:
CodingCat 2016-03-11 20:07:09 -05:00
parent cbabaeba0c
commit 5f441a29a8

View File

@ -81,12 +81,15 @@ object XGBoost extends Serializable {
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int, def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
nWorkers: Int = 0, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = { nWorkers: Int = 0, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
implicit val sc = trainingData.sparkContext implicit val sc = trainingData.sparkContext
if (configMap.contains("nthread")) { var overridedConfMap = configMap
val nThread = configMap("nthread") if (overridedConfMap.contains("nthread")) {
val nThread = overridedConfMap("nthread")
val coresPerTask = sc.getConf.get("spark.task.cpus", "1") val coresPerTask = sc.getConf.get("spark.task.cpus", "1")
require(nThread.toString <= coresPerTask, require(nThread.toString <= coresPerTask,
s"the nthread configuration ($nThread) must be no larger than " + s"the nthread configuration ($nThread) must be no larger than " +
s"spark.task.cpus ($coresPerTask)") s"spark.task.cpus ($coresPerTask)")
} else {
overridedConfMap = configMap + ("nthread" -> sc.getConf.get("spark.task.cpus", "1").toInt)
} }
val numWorkers = { val numWorkers = {
if (nWorkers > 0) { if (nWorkers > 0) {
@ -97,7 +100,7 @@ object XGBoost extends Serializable {
} }
val tracker = new RabitTracker(numWorkers) val tracker = new RabitTracker(numWorkers)
require(tracker.start(), "FAULT: Failed to start tracker") require(tracker.start(), "FAULT: Failed to start tracker")
val boosters = buildDistributedBoosters(trainingData, configMap, val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval) tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
val sparkJobThread = new Thread() { val sparkJobThread = new Thread() {
override def run() { override def run() {