nthread no larger than spark.task.cpus
This commit is contained in:
@@ -70,6 +70,13 @@ object XGBoost extends Serializable {
|
||||
obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
|
||||
val numWorkers = trainingData.partitions.length
|
||||
implicit val sc = trainingData.sparkContext
|
||||
if (configMap.contains("nthread")) {
|
||||
val nThread = configMap("nthread")
|
||||
val coresPerTask = sc.getConf.get("spark.task.cpus", "1")
|
||||
require(nThread.toString <= coresPerTask,
|
||||
s"the nthread configuration ($nThread) must be no larger than " +
|
||||
s"spark.task.cpus ($coresPerTask)")
|
||||
}
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
val boosters = buildDistributedBoosters(trainingData, configMap,
|
||||
|
||||
Reference in New Issue
Block a user