set nthread to spark.task.cpus by default

CodingCat 2016-03-11 20:07:09 -05:00
parent cbabaeba0c
commit 5f441a29a8

@@ -81,12 +81,15 @@ object XGBoost extends Serializable {
   def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
       nWorkers: Int = 0, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
     implicit val sc = trainingData.sparkContext
-    if (configMap.contains("nthread")) {
-      val nThread = configMap("nthread")
+    var overridedConfMap = configMap
+    if (overridedConfMap.contains("nthread")) {
+      val nThread = overridedConfMap("nthread")
       val coresPerTask = sc.getConf.get("spark.task.cpus", "1")
       require(nThread.toString.toInt <= coresPerTask.toInt,
         s"the nthread configuration ($nThread) must be no larger than " +
         s"spark.task.cpus ($coresPerTask)")
-    }
+    } else {
+      overridedConfMap = configMap + ("nthread" -> sc.getConf.get("spark.task.cpus", "1").toInt)
+    }
     val numWorkers = {
       if (nWorkers > 0) {
@@ -97,7 +100,7 @@ object XGBoost extends Serializable {
     }
     val tracker = new RabitTracker(numWorkers)
     require(tracker.start(), "FAULT: Failed to start tracker")
-    val boosters = buildDistributedBoosters(trainingData, configMap,
+    val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
       tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
     val sparkJobThread = new Thread() {
       override def run() {
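
The net effect for callers: when "nthread" is omitted from the config map, train now fills it in from spark.task.cpus instead of leaving it to the booster's default, and an explicit "nthread" larger than spark.task.cpus is rejected up front. A minimal usage sketch of the new behavior; the trainingRDD, app name, parameter values, and the XGBoost/LabeledPoint import paths below are assumptions for illustration, while the train signature itself comes from the diff above:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint // assumed: the LabeledPoint used by train
import ml.dmlc.xgboost4j.scala.spark.XGBoost          // assumed package path

// Give every Spark task 4 cores. After this commit, XGBoost.train copies
// spark.task.cpus into the booster config as "nthread" -> 4 whenever the
// caller leaves "nthread" out.
val conf = new SparkConf()
  .setAppName("xgboost-nthread-default") // hypothetical app name
  .set("spark.task.cpus", "4")
val sc = new SparkContext(conf)

val trainingRDD: RDD[LabeledPoint] = ??? // hypothetical placeholder for real data

// No "nthread" key here: each distributed booster now trains with 4 threads.
val params = Map[String, Any]("eta" -> 0.1, "max_depth" -> 6)
val model = XGBoost.train(trainingRDD, params, round = 10, nWorkers = 2)

// An explicit "nthread" above spark.task.cpus still fails fast via the
// require(...) in train, which throws IllegalArgumentException:
// XGBoost.train(trainingRDD, params + ("nthread" -> 8), round = 10, nWorkers = 2)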