allow the user to specify the worker number and avoid unnecessary shuffle

This commit is contained in:
CodingCat
2016-03-10 06:58:30 -05:00
parent e0a3f1c000
commit d47df5c1d8
2 changed files with 29 additions and 13 deletions

View File

@@ -35,7 +35,7 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
class XGBoostSuite extends FunSuite with BeforeAndAfter {
private implicit var sc: SparkContext = null
private val numWorker = 2
private val numWorkers = 4
private class EvalError extends EvalTrait {
@@ -114,10 +114,10 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
private def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
val sampleList = readFile(getClass.getResource("/agaricus.txt.train").getFile)
sparkContext.getOrElse(sc).parallelize(sampleList, numWorker)
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
}
test("build RDD containing boosters") {
test("build RDD containing boosters with the specified worker number") {
val trainingRDD = buildTrainingRDD()
val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
@@ -127,13 +127,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap,
new scala.collection.mutable.HashMap[String, String],
numWorker, 2, null, null)
numWorkers = 2, round = 5, null, null)
val boosterCount = boosterRDD.count()
assert(boosterCount === numWorker)
assert(boosterCount === 2)
val boosters = boosterRDD.collect()
for (booster <- boosters) {
val predicts = booster.predict(testSetDMatrix, true)
assert(new EvalError().eval(predicts, testSetDMatrix) < 0.1)
assert(new EvalError().eval(predicts, testSetDMatrix) < 0.17)
}
}
@@ -157,13 +157,12 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
}
test("nthread configuration must be equal to spark.task.cpus") {
// close the current Spark context
sc.stop()
sc = null
// start another app
val sparkConf = new SparkConf().setMaster("local[*]").set("spark.task.cpus", "4").
setAppName("test1")
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
set("spark.task.cpus", "4")
val customSparkContext = new SparkContext(sparkConf)
// start another app
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic", "nthread" -> 6).toMap