From d47df5c1d803463f2e2f83207a604391f1e56f93 Mon Sep 17 00:00:00 2001
From: CodingCat
Date: Thu, 10 Mar 2016 06:58:30 -0500
Subject: [PATCH] allow the user to specify the worker number and avoid
 unnecessary shuffle

---
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala      | 23 ++++++++++++++++---
 .../xgboost4j/scala/spark/XGBoostSuite.scala      | 19 ++++++++-------
 2 files changed, 29 insertions(+), 13 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index a68526b17..0122445c6 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -43,7 +43,16 @@ object XGBoost extends Serializable {
       rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
     import DataUtils._
-    trainingData.repartition(numWorkers).mapPartitions {
+    val partitionedData = {
+      if (numWorkers > trainingData.partitions.length) {
+        trainingData.repartition(numWorkers)
+      } else if (numWorkers < trainingData.partitions.length) {
+        trainingData.coalesce(numWorkers)
+      } else {
+        trainingData
+      }
+    }
+    partitionedData.mapPartitions {
       trainingSamples =>
         rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
         Rabit.init(rabitEnv.asJava)
@@ -60,6 +69,8 @@ object XGBoost extends Serializable {
    * @param trainingData the trainingset represented as RDD
    * @param configMap Map containing the configuration entries
    * @param round the number of iterations
+   * @param nWorkers the number of XGBoost workers; 0 by default, which means that the
+   *                 number of workers equals the number of partitions of the trainingData RDD
    * @param obj the user-defined objective function, null by default
    * @param eval the user-defined evaluation function, null by default
    * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
@@ -67,8 +78,7 @@ object XGBoost extends Serializable {
    */
   @throws(classOf[XGBoostError])
   def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
-      obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
-    val numWorkers = trainingData.partitions.length
+      nWorkers: Int = 0, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
     implicit val sc = trainingData.sparkContext
     if (configMap.contains("nthread")) {
       val nThread = configMap("nthread")
@@ -77,6 +87,13 @@ object XGBoost extends Serializable {
         s"the nthread configuration ($nThread) must be no larger than " +
           s"spark.task.cpus ($coresPerTask)")
     }
+    val numWorkers = {
+      if (nWorkers > 0) {
+        nWorkers
+      } else {
+        trainingData.partitions.length
+      }
+    }
     val tracker = new RabitTracker(numWorkers)
     require(tracker.start(), "FAULT: Failed to start tracker")
     val boosters = buildDistributedBoosters(trainingData, configMap,
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
index 21ef3f7a2..91a12530a 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
@@ -35,7 +35,7 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
 class XGBoostSuite extends FunSuite with BeforeAndAfter {
 
   private implicit var sc: SparkContext = null
-  private val numWorker = 2
+  private val numWorkers = 4
 
   private class EvalError extends EvalTrait {
 
@@ -114,10 +114,10 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
 
   private def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
     val sampleList = readFile(getClass.getResource("/agaricus.txt.train").getFile)
-    sparkContext.getOrElse(sc).parallelize(sampleList, numWorker)
+    sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
   }
 
-  test("build RDD containing boosters") {
+  test("build RDD containing boosters with the specified worker number") {
     val trainingRDD = buildTrainingRDD()
     val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
@@ -127,13 +127,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
       List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
         "objective" -> "binary:logistic").toMap,
       new scala.collection.mutable.HashMap[String, String],
-      numWorker, 2, null, null)
+      numWorkers = 2, round = 5, obj = null, eval = null)
     val boosterCount = boosterRDD.count()
-    assert(boosterCount === numWorker)
+    assert(boosterCount === 2)
     val boosters = boosterRDD.collect()
     for (booster <- boosters) {
       val predicts = booster.predict(testSetDMatrix, true)
-      assert(new EvalError().eval(predicts, testSetDMatrix) < 0.1)
+      assert(new EvalError().eval(predicts, testSetDMatrix) < 0.17)
     }
   }
 
@@ -157,13 +157,12 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
   }
 
   test("nthread configuration must be equal to spark.task.cpus") {
-    // close the current Spark context
     sc.stop()
     sc = null
-    // start another app
-    val sparkConf = new SparkConf().setMaster("local[*]").set("spark.task.cpus", "4").
-      setAppName("test1")
+    // start another app with spark.task.cpus set to 4
+    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
+      set("spark.task.cpus", "4")
     val customSparkContext = new SparkContext(sparkConf)
     val trainingRDD = buildTrainingRDD(Some(customSparkContext))
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
       "objective" -> "binary:logistic", "nthread" -> 6).toMap
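
Usage note (illustrative, not part of the patch): the sketch below shows how a driver
program might call XGBoost.train with the new nWorkers parameter. Only XGBoost.train
and its nWorkers argument come from the diff above; the object name, the SparkContext
setup, the file path, and the use of MLlib's LibSVM loader (which yields the MLlib
LabeledPoint that the test suite also appears to use) are assumptions made here.

    // Hypothetical driver program; everything except XGBoost.train(..., nWorkers = ...)
    // is assumed for illustration, not defined by this patch.
    import ml.dmlc.xgboost4j.scala.spark.XGBoost
    import org.apache.spark.mllib.util.MLUtils
    import org.apache.spark.{SparkConf, SparkContext}

    object TrainExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("TrainExample"))
        // Any RDD[LabeledPoint] works; a LibSVM file is loaded here for convenience.
        val trainingData = MLUtils.loadLibSVMFile(sc, "agaricus.txt.train")
        val params = Map("eta" -> "1", "max_depth" -> "2", "objective" -> "binary:logistic")

        // nWorkers left at its default of 0: one XGBoost worker per existing partition,
        // so no repartitioning (and therefore no shuffle) happens at all.
        val modelDefault = XGBoost.train(trainingData, params, round = 5)

        // Explicit worker count: repartition() runs only when more partitions are needed;
        // coalesce(), which avoids a full shuffle, is used when fewer are needed.
        val modelFourWorkers = XGBoost.train(trainingData, params, round = 5, nWorkers = 4)

        sc.stop()
      }
    }

The nWorkers = 0 default keeps the old one-worker-per-partition behavior while dropping
the previously unconditional repartition call, which is the unnecessary shuffle the
subject line refers to.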