allow the user define how many workers they need
This commit is contained in:
parent
909c6af330
commit
a08cc8aad4
@ -22,17 +22,17 @@ import ml.dmlc.xgboost4j.scala.spark.XGBoost
|
|||||||
|
|
||||||
object DistTrainWithSpark {
|
object DistTrainWithSpark {
|
||||||
def main(args: Array[String]): Unit = {
|
def main(args: Array[String]): Unit = {
|
||||||
if (args.length != 3) {
|
if (args.length != 4) {
|
||||||
println(
|
println(
|
||||||
"usage: program num_of_rounds training_path model_path")
|
"usage: program num_of_rounds num_workers training_path model_path")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
}
|
}
|
||||||
val sc = new SparkContext()
|
val sc = new SparkContext()
|
||||||
val inputTrainPath = args(1)
|
val inputTrainPath = args(2)
|
||||||
val outputModelPath = args(2)
|
val outputModelPath = args(3)
|
||||||
// number of iterations
|
// number of iterations
|
||||||
val numRound = args(0).toInt
|
val numRound = args(0).toInt
|
||||||
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
|
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).repartition(args(1).toInt)
|
||||||
// training parameters
|
// training parameters
|
||||||
val paramMap = List(
|
val paramMap = List(
|
||||||
"eta" -> 0.1f,
|
"eta" -> 0.1f,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user