diff --git a/doc/jvm/xgboost4j-intro.md b/doc/jvm/xgboost4j-intro.md index ffe85bd4b..699b3589b 100644 --- a/doc/jvm/xgboost4j-intro.md +++ b/doc/jvm/xgboost4j-intro.md @@ -147,7 +147,7 @@ val trainData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train") Model Training can be done as follows ```scala -val xgboostModel = XGBoost.train(trainData, paramMap, round, nWorkers) +val xgboostModel = XGBoost.train(trainData, paramMap, round) ``` diff --git a/jvm-packages/README.md b/jvm-packages/README.md index 62aa79268..e1dfb1576 100644 --- a/jvm-packages/README.md +++ b/jvm-packages/README.md @@ -72,7 +72,7 @@ object DistTrainWithSpark { "eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic").toMap - // use 5 distributed workers to train the model + // use 5 distributed workers to train the model val model = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = 5) // save model to HDFS path model.saveModelToHadoop(outputModelPath) @@ -100,9 +100,8 @@ object DistTrainWithFlink { "objective" -> "binary:logistic").toMap // number of iterations val round = 2 - val nWorkers = 5 // train the model - val model = XGBoost.train(trainData, paramMap, round, nWorkers) + val model = XGBoost.train(trainData, paramMap, round) val predTrain = model.predict(trainData.map{x => x.vector}) model.saveModelToHadoop("file:///path/to/xgboost.model") } diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala index cd2dacb5a..74b24ac35 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala @@ -33,9 +33,8 @@ object DistTrainWithFlink { "objective" -> "binary:logistic").toMap // number of iterations val round = 2 - val nWorkers = 5 // train the model - val model = XGBoost.train(trainData, paramMap, round, 5) + val model = XGBoost.train(trainData, paramMap, round) val predTest = model.predict(testData.map{x => x.vector}) model.saveModelAsHadoopFile("file:///path/to/xgboost.model") } diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala index 3056d28a1..4c6adee99 100644 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala @@ -82,9 +82,9 @@ object XGBoost { * @param params The parameters to XGBoost. * @param round Number of rounds to train. */ - def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int, nWorkers: Int): + def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int): XGBoostModel = { - val tracker = new RabitTracker(nWorkers) + val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism) if (tracker.start()) { dtrain .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))