Allow the user to specify the number of workers and avoid an unnecessary shuffle

This commit is contained in:
CodingCat 2016-03-10 06:58:30 -05:00
parent e0a3f1c000
commit d47df5c1d8
2 changed files with 29 additions and 13 deletions

View File

@ -43,7 +43,16 @@ object XGBoost extends Serializable {
rabitEnv: mutable.Map[String, String],
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
import DataUtils._
trainingData.repartition(numWorkers).mapPartitions {
val partitionedData = {
if (numWorkers > trainingData.partitions.length) {
trainingData.repartition(numWorkers)
} else if (numWorkers < trainingData.partitions.length) {
trainingData.coalesce(numWorkers)
} else {
trainingData
}
}
partitionedData.mapPartitions {
trainingSamples =>
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
@ -60,6 +69,8 @@ object XGBoost extends Serializable {
* @param trainingData the trainingset represented as RDD
* @param configMap Map containing the configuration entries
* @param round the number of iterations
* @param nWorkers the number of XGBoost workers; 0 (the default) means that the number of
* workers equals the number of partitions of the trainingData RDD
* @param obj the user-defined objective function, null by default
* @param eval the user-defined evaluation function, null by default
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
@ -67,8 +78,7 @@ object XGBoost extends Serializable {
*/
@throws(classOf[XGBoostError])
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
val numWorkers = trainingData.partitions.length
nWorkers: Int = 0, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
implicit val sc = trainingData.sparkContext
if (configMap.contains("nthread")) {
val nThread = configMap("nthread")
@ -77,6 +87,13 @@ object XGBoost extends Serializable {
s"the nthread configuration ($nThread) must be no larger than " +
s"spark.task.cpus ($coresPerTask)")
}
val numWorkers = {
if (nWorkers > 0) {
nWorkers
} else {
trainingData.partitions.length
}
}
val tracker = new RabitTracker(numWorkers)
require(tracker.start(), "FAULT: Failed to start tracker")
val boosters = buildDistributedBoosters(trainingData, configMap,

View File

@ -35,7 +35,7 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
class XGBoostSuite extends FunSuite with BeforeAndAfter {
private implicit var sc: SparkContext = null
private val numWorker = 2
private val numWorkers = 4
private class EvalError extends EvalTrait {
@ -114,10 +114,10 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
private def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
val sampleList = readFile(getClass.getResource("/agaricus.txt.train").getFile)
sparkContext.getOrElse(sc).parallelize(sampleList, numWorker)
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
}
test("build RDD containing boosters") {
test("build RDD containing boosters with the specified worker number") {
val trainingRDD = buildTrainingRDD()
val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
@ -127,13 +127,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap,
new scala.collection.mutable.HashMap[String, String],
numWorker, 2, null, null)
numWorkers = 2, round = 5, null, null)
val boosterCount = boosterRDD.count()
assert(boosterCount === numWorker)
assert(boosterCount === 2)
val boosters = boosterRDD.collect()
for (booster <- boosters) {
val predicts = booster.predict(testSetDMatrix, true)
assert(new EvalError().eval(predicts, testSetDMatrix) < 0.1)
assert(new EvalError().eval(predicts, testSetDMatrix) < 0.17)
}
}
@ -157,13 +157,12 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
}
test("nthread configuration must be equal to spark.task.cpus") {
// close the current Spark context
sc.stop()
sc = null
// start another app
val sparkConf = new SparkConf().setMaster("local[*]").set("spark.task.cpus", "4").
setAppName("test1")
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
set("spark.task.cpus", "4")
val customSparkContext = new SparkContext(sparkConf)
// start another app
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic", "nthread" -> 6).toMap