diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 4b4da36cb..5fcdf81e5 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -81,12 +81,15 @@ object XGBoost extends Serializable { def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int, nWorkers: Int = 0, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = { implicit val sc = trainingData.sparkContext - if (configMap.contains("nthread")) { - val nThread = configMap("nthread") + var overridedConfMap = configMap + if (overridedConfMap.contains("nthread")) { + val nThread = overridedConfMap("nthread") val coresPerTask = sc.getConf.get("spark.task.cpus", "1") require(nThread.toString <= coresPerTask, s"the nthread configuration ($nThread) must be no larger than " + s"spark.task.cpus ($coresPerTask)") + } else { + overridedConfMap = configMap + ("nthread" -> sc.getConf.get("spark.task.cpus", "1").toInt) } val numWorkers = { if (nWorkers > 0) { @@ -97,7 +100,7 @@ object XGBoost extends Serializable { } val tracker = new RabitTracker(numWorkers) require(tracker.start(), "FAULT: Failed to start tracker") - val boosters = buildDistributedBoosters(trainingData, configMap, + val boosters = buildDistributedBoosters(trainingData, overridedConfMap, tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval) val sparkJobThread = new Thread() { override def run() {