diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/spark/example/DistTrainWithSpark.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/spark/example/DistTrainWithSpark.scala index 5a3bb0676..4f501825f 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/spark/example/DistTrainWithSpark.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/spark/example/DistTrainWithSpark.scala @@ -22,17 +22,17 @@ import ml.dmlc.xgboost4j.scala.spark.XGBoost object DistTrainWithSpark { def main(args: Array[String]): Unit = { - if (args.length != 3) { + if (args.length != 4) { println( - "usage: program num_of_rounds training_path model_path") + "usage: program num_of_rounds num_workers training_path model_path") sys.exit(1) } val sc = new SparkContext() - val inputTrainPath = args(1) - val outputModelPath = args(2) + val inputTrainPath = args(2) + val outputModelPath = args(3) // number of iterations val numRound = args(0).toInt - val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath) + val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).repartition(args(1).toInt) // training parameters val paramMap = List( "eta" -> 0.1f,