allow the user to define how many workers they need

This commit is contained in:
CodingCat 2016-03-08 18:46:53 -05:00
parent 909c6af330
commit a08cc8aad4

View File

@ -22,17 +22,17 @@ import ml.dmlc.xgboost4j.scala.spark.XGBoost
object DistTrainWithSpark { object DistTrainWithSpark {
def main(args: Array[String]): Unit = { def main(args: Array[String]): Unit = {
if (args.length != 3) { if (args.length != 4) {
println( println(
"usage: program num_of_rounds training_path model_path") "usage: program num_of_rounds num_workers training_path model_path")
sys.exit(1) sys.exit(1)
} }
val sc = new SparkContext() val sc = new SparkContext()
val inputTrainPath = args(1) val inputTrainPath = args(2)
val outputModelPath = args(2) val outputModelPath = args(3)
// number of iterations // number of iterations
val numRound = args(0).toInt val numRound = args(0).toInt
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath) val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).repartition(args(1).toInt)
// training parameters // training parameters
val paramMap = List( val paramMap = List(
"eta" -> 0.1f, "eta" -> 0.1f,