[FLINK] remove nWorker from API
This commit is contained in:
@@ -82,9 +82,9 @@ object XGBoost {
|
||||
* @param params The parameters to XGBoost.
|
||||
* @param round Number of rounds to train.
|
||||
*/
|
||||
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int, nWorkers: Int):
|
||||
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
|
||||
XGBoostModel = {
|
||||
val tracker = new RabitTracker(nWorkers)
|
||||
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
|
||||
if (tracker.start()) {
|
||||
dtrain
|
||||
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
|
||||
|
||||
Reference in New Issue
Block a user