Merge pull request #954 from CodingCat/worker_num
[jvm-packages] allow the user to specify the worker number and avoid unnecessary shuffle
commit d913cbbfbc

@@ -1 +1 @@
-Subproject commit 4e6459b0bc15e6cf9b315cc75e2e5495c03cd417
+Subproject commit 1db0792e1a55355b1f07699bba18c88ded996953
@@ -43,7 +43,16 @@ object XGBoost extends Serializable {
       rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
     import DataUtils._
-    trainingData.repartition(numWorkers).mapPartitions {
+    val partitionedData = {
+      if (numWorkers > trainingData.partitions.length) {
+        trainingData.repartition(numWorkers)
+      } else if (numWorkers < trainingData.partitions.length) {
+        trainingData.coalesce(numWorkers)
+      } else {
+        trainingData
+      }
+    }
+    partitionedData.mapPartitions {
       trainingSamples =>
         rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
         Rabit.init(rabitEnv.asJava)
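The rewritten block above is what the PR title means by avoiding an unnecessary shuffle: the old code unconditionally called repartition(), which moves every record across the cluster, while coalesce() merely merges existing partitions when the worker count shrinks. A minimal standalone sketch of the same rule (the helper name adjustPartitions is hypothetical, not part of this commit):

    import org.apache.spark.rdd.RDD

    // Pick the cheapest way to match the partition count to the worker count.
    def adjustPartitions[T](data: RDD[T], numWorkers: Int): RDD[T] = {
      if (numWorkers > data.partitions.length) {
        data.repartition(numWorkers)  // growing requires a full shuffle
      } else if (numWorkers < data.partitions.length) {
        data.coalesce(numWorkers)     // shrinking merges partitions, no shuffle
      } else {
        data                          // counts already match: leave the RDD as-is
      }
    }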
@@ -60,6 +69,8 @@ object XGBoost extends Serializable {
    * @param trainingData the trainingset represented as RDD
    * @param configMap Map containing the configuration entries
    * @param round the number of iterations
+   * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
+   *                 workers equals to the partition number of trainingData RDD
    * @param obj the user-defined objective function, null by default
    * @param eval the user-defined evaluation function, null by default
    * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
@@ -67,9 +78,22 @@ object XGBoost extends Serializable {
    */
   @throws(classOf[XGBoostError])
   def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
-      obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
-    val numWorkers = trainingData.partitions.length
+      nWorkers: Int = 0, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
     implicit val sc = trainingData.sparkContext
+    if (configMap.contains("nthread")) {
+      val nThread = configMap("nthread")
+      val coresPerTask = sc.getConf.get("spark.task.cpus", "1")
+      require(nThread.toString <= coresPerTask,
+        s"the nthread configuration ($nThread) must be no larger than " +
+          s"spark.task.cpus ($coresPerTask)")
+    }
+    val numWorkers = {
+      if (nWorkers > 0) {
+        nWorkers
+      } else {
+        trainingData.partitions.length
+      }
+    }
     val tracker = new RabitTracker(numWorkers)
     require(tracker.start(), "FAULT: Failed to start tracker")
     val boosters = buildDistributedBoosters(trainingData, configMap,
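With the new nWorkers parameter, callers can decouple the number of xgboost workers from however many partitions the input RDD happens to have; the default of 0 preserves the old one-worker-per-partition behavior. A hedged usage sketch (trainingRDD and the parameter values are illustrative, not from the PR):

    // trainingRDD: an RDD[LabeledPoint] prepared elsewhere.
    val paramMap = Map(
      "eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic")
    // Default: nWorkers = 0, i.e. trainingRDD.partitions.length workers.
    val modelDefault = XGBoost.train(trainingRDD, paramMap, round = 5)
    // Explicit: exactly 8 workers; train() repartitions or coalesces to match.
    val modelFixed = XGBoost.train(trainingRDD, paramMap, round = 5, nWorkers = 8)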
@@ -27,15 +27,15 @@ import org.apache.spark.mllib.linalg.DenseVector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkConf, SparkContext}
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, FunSuite}
 
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
 import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
 
-class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
+class XGBoostSuite extends FunSuite with BeforeAndAfter {
 
   private implicit var sc: SparkContext = null
-  private val numWorker = 2
+  private val numWorkers = 4
 
   private class EvalError extends EvalTrait {
 
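The suite swaps BeforeAndAfterAll for BeforeAndAfter because the new nthread test further down tears down the shared SparkContext and builds its own: before/after blocks run around every individual test, so each test starts from a fresh context. A minimal illustration of the difference (a standalone sketch, not from this commit):

    import org.scalatest.{BeforeAndAfter, FunSuite}

    // before/after run around EACH test, unlike BeforeAndAfterAll's
    // beforeAll/afterAll, which run once for the whole suite.
    class LifecycleDemo extends FunSuite with BeforeAndAfter {
      private var resource: String = _
      before { resource = "fresh" }
      after { resource = null }

      test("every test sees a freshly initialized resource") {
        assert(resource == "fresh")
      }
    }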
@@ -79,13 +79,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
     }
   }
 
-  override def beforeAll(): Unit = {
+  before {
     // build SparkContext
     val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
     sc = new SparkContext(sparkConf)
   }
 
-  override def afterAll(): Unit = {
+  after {
     if (sc != null) {
       sc.stop()
     }
@@ -112,12 +112,12 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
     sampleList.toList
   }
 
-  private def buildTrainingRDD(): RDD[LabeledPoint] = {
+  private def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
     val sampleList = readFile(getClass.getResource("/agaricus.txt.train").getFile)
-    sc.parallelize(sampleList, numWorker)
+    sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
   }
 
-  test("build RDD containing boosters") {
+  test("build RDD containing boosters with the specified worker number") {
     val trainingRDD = buildTrainingRDD()
     val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
@@ -127,13 +127,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
       List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
         "objective" -> "binary:logistic").toMap,
       new scala.collection.mutable.HashMap[String, String],
-      numWorker, 2, null, null)
+      numWorkers = 2, round = 5, null, null)
     val boosterCount = boosterRDD.count()
-    assert(boosterCount === numWorker)
+    assert(boosterCount === 2)
     val boosters = boosterRDD.collect()
     for (booster <- boosters) {
       val predicts = booster.predict(testSetDMatrix, true)
-      assert(new EvalError().eval(predicts, testSetDMatrix) < 0.1)
+      assert(new EvalError().eval(predicts, testSetDMatrix) < 0.17)
     }
   }
 
@@ -155,4 +155,20 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
     val predicts = loadedXGBooostModel.predict(testSetDMatrix)
     assert(eval.eval(predicts, testSetDMatrix) < 0.1)
   }
+
+  test("nthread configuration must be equal to spark.task.cpus") {
+    sc.stop()
+    sc = null
+    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
+      set("spark.task.cpus", "4")
+    val customSparkContext = new SparkContext(sparkConf)
+    // start another app
+    val trainingRDD = buildTrainingRDD(Some(customSparkContext))
+    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+      "objective" -> "binary:logistic", "nthread" -> 6).toMap
+    intercept[IllegalArgumentException] {
+      XGBoost.train(trainingRDD, paramMap, 5)
+    }
+    customSparkContext.stop()
+  }
 }
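The new test raises nthread (6) above spark.task.cpus (4) and expects the require in train() to throw IllegalArgumentException. One caveat: that require compares nThread.toString <= coresPerTask as strings, which works for the single-digit values exercised here but would order "10" before "4" lexicographically. A numeric comparison (a sketch, not part of this commit) would close that edge case:

    // Sketch: compare the two settings as integers rather than as strings.
    val nThread = configMap("nthread").toString.toInt
    val coresPerTask = sc.getConf.get("spark.task.cpus", "1").toInt
    require(nThread <= coresPerTask,
      s"the nthread configuration ($nThread) must be no larger than " +
        s"spark.task.cpus ($coresPerTask)")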