set nthread to spark.task.cpus by default
This commit is contained in:
parent
cbabaeba0c
commit
5f441a29a8
@ -81,12 +81,15 @@ object XGBoost extends Serializable {
|
|||||||
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
|
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
|
||||||
nWorkers: Int = 0, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
|
nWorkers: Int = 0, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
|
||||||
implicit val sc = trainingData.sparkContext
|
implicit val sc = trainingData.sparkContext
|
||||||
if (configMap.contains("nthread")) {
|
var overridedConfMap = configMap
|
||||||
val nThread = configMap("nthread")
|
if (overridedConfMap.contains("nthread")) {
|
||||||
|
val nThread = overridedConfMap("nthread")
|
||||||
val coresPerTask = sc.getConf.get("spark.task.cpus", "1")
|
val coresPerTask = sc.getConf.get("spark.task.cpus", "1")
|
||||||
require(nThread.toString <= coresPerTask,
|
require(nThread.toString <= coresPerTask,
|
||||||
s"the nthread configuration ($nThread) must be no larger than " +
|
s"the nthread configuration ($nThread) must be no larger than " +
|
||||||
s"spark.task.cpus ($coresPerTask)")
|
s"spark.task.cpus ($coresPerTask)")
|
||||||
|
} else {
|
||||||
|
overridedConfMap = configMap + ("nthread" -> sc.getConf.get("spark.task.cpus", "1").toInt)
|
||||||
}
|
}
|
||||||
val numWorkers = {
|
val numWorkers = {
|
||||||
if (nWorkers > 0) {
|
if (nWorkers > 0) {
|
||||||
@ -97,7 +100,7 @@ object XGBoost extends Serializable {
|
|||||||
}
|
}
|
||||||
val tracker = new RabitTracker(numWorkers)
|
val tracker = new RabitTracker(numWorkers)
|
||||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||||
val boosters = buildDistributedBoosters(trainingData, configMap,
|
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
|
||||||
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
|
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
|
||||||
val sparkJobThread = new Thread() {
|
val sparkJobThread = new Thread() {
|
||||||
override def run() {
|
override def run() {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user