set nthread to spark.task.cpus by default
This commit is contained in:
parent
cbabaeba0c
commit
5f441a29a8
@ -81,12 +81,15 @@ object XGBoost extends Serializable {
|
||||
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
|
||||
nWorkers: Int = 0, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
|
||||
implicit val sc = trainingData.sparkContext
|
||||
if (configMap.contains("nthread")) {
|
||||
val nThread = configMap("nthread")
|
||||
var overridedConfMap = configMap
|
||||
if (overridedConfMap.contains("nthread")) {
|
||||
val nThread = overridedConfMap("nthread")
|
||||
val coresPerTask = sc.getConf.get("spark.task.cpus", "1")
|
||||
require(nThread.toString <= coresPerTask,
|
||||
s"the nthread configuration ($nThread) must be no larger than " +
|
||||
s"spark.task.cpus ($coresPerTask)")
|
||||
} else {
|
||||
overridedConfMap = configMap + ("nthread" -> sc.getConf.get("spark.task.cpus", "1").toInt)
|
||||
}
|
||||
val numWorkers = {
|
||||
if (nWorkers > 0) {
|
||||
@ -97,7 +100,7 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
val boosters = buildDistributedBoosters(trainingData, configMap,
|
||||
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
|
||||
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user