allow the user to specify the worker number and avoid unnecessary shuffle
This commit is contained in:
@@ -35,7 +35,7 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
|
||||
class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
||||
|
||||
private implicit var sc: SparkContext = null
|
||||
private val numWorker = 2
|
||||
private val numWorkers = 4
|
||||
|
||||
private class EvalError extends EvalTrait {
|
||||
|
||||
@@ -114,10 +114,10 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
||||
|
||||
private def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
|
||||
val sampleList = readFile(getClass.getResource("/agaricus.txt.train").getFile)
|
||||
sparkContext.getOrElse(sc).parallelize(sampleList, numWorker)
|
||||
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
|
||||
}
|
||||
|
||||
test("build RDD containing boosters") {
|
||||
test("build RDD containing boosters with the specified worker number") {
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
@@ -127,13 +127,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
||||
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
"objective" -> "binary:logistic").toMap,
|
||||
new scala.collection.mutable.HashMap[String, String],
|
||||
numWorker, 2, null, null)
|
||||
numWorkers = 2, round = 5, null, null)
|
||||
val boosterCount = boosterRDD.count()
|
||||
assert(boosterCount === numWorker)
|
||||
assert(boosterCount === 2)
|
||||
val boosters = boosterRDD.collect()
|
||||
for (booster <- boosters) {
|
||||
val predicts = booster.predict(testSetDMatrix, true)
|
||||
assert(new EvalError().eval(predicts, testSetDMatrix) < 0.1)
|
||||
assert(new EvalError().eval(predicts, testSetDMatrix) < 0.17)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,13 +157,12 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
|
||||
}
|
||||
|
||||
test("nthread configuration must be equal to spark.task.cpus") {
|
||||
// close the current Spark context
|
||||
sc.stop()
|
||||
sc = null
|
||||
// start another app
|
||||
val sparkConf = new SparkConf().setMaster("local[*]").set("spark.task.cpus", "4").
|
||||
setAppName("test1")
|
||||
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
|
||||
set("spark.task.cpus", "4")
|
||||
val customSparkContext = new SparkContext(sparkConf)
|
||||
// start another app
|
||||
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "nthread" -> 6).toMap
|
||||
|
||||
Reference in New Issue
Block a user