Allow the user to specify the number of workers and avoid an unnecessary shuffle
This commit is contained in:
@@ -43,7 +43,16 @@ object XGBoost extends Serializable {
|
||||
rabitEnv: mutable.Map[String, String],
|
||||
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
|
||||
import DataUtils._
|
||||
trainingData.repartition(numWorkers).mapPartitions {
|
||||
val partitionedData = {
|
||||
if (numWorkers > trainingData.partitions.length) {
|
||||
trainingData.repartition(numWorkers)
|
||||
} else if (numWorkers < trainingData.partitions.length) {
|
||||
trainingData.coalesce(numWorkers)
|
||||
} else {
|
||||
trainingData
|
||||
}
|
||||
}
|
||||
partitionedData.mapPartitions {
|
||||
trainingSamples =>
|
||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
@@ -60,6 +69,8 @@ object XGBoost extends Serializable {
|
||||
* @param trainingData the trainingset represented as RDD
|
||||
* @param configMap Map containing the configuration entries
|
||||
* @param round the number of iterations
|
||||
* @param nWorkers the number of xgboost workers; 0 (the default) means that the number of
|
||||
* workers equals the number of partitions of the trainingData RDD
|
||||
* @param obj the user-defined objective function, null by default
|
||||
* @param eval the user-defined evaluation function, null by default
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training fails
|
||||
@@ -67,8 +78,7 @@ object XGBoost extends Serializable {
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
|
||||
obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
|
||||
val numWorkers = trainingData.partitions.length
|
||||
nWorkers: Int = 0, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
|
||||
implicit val sc = trainingData.sparkContext
|
||||
if (configMap.contains("nthread")) {
|
||||
val nThread = configMap("nthread")
|
||||
@@ -77,6 +87,13 @@ object XGBoost extends Serializable {
|
||||
s"the nthread configuration ($nThread) must be no larger than " +
|
||||
s"spark.task.cpus ($coresPerTask)")
|
||||
}
|
||||
val numWorkers = {
|
||||
if (nWorkers > 0) {
|
||||
nWorkers
|
||||
} else {
|
||||
trainingData.partitions.length
|
||||
}
|
||||
}
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
val boosters = buildDistributedBoosters(trainingData, configMap,
|
||||
|
||||
Reference in New Issue
Block a user