From 81dbf564a475919ebd7a2bcf75ce00162c15e741 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 5 Mar 2016 09:38:43 -0800 Subject: [PATCH] [Flink] Check --- jvm-packages/test_distributed.sh | 4 +- jvm-packages/xgboost4j-flink/pom.xml | 47 ++++++++++++ .../scala/ml/dmlc/xgboost4j/flink/Test.scala | 71 +++++++++++++++++++ .../java/ml/dmlc/xgboost4j/RabitTracker.java | 12 ++-- 4 files changed, 128 insertions(+), 6 deletions(-) create mode 100644 jvm-packages/xgboost4j-flink/pom.xml create mode 100644 jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala diff --git a/jvm-packages/test_distributed.sh b/jvm-packages/test_distributed.sh index c9a5b21be..b17f6a3b3 100755 --- a/jvm-packages/test_distributed.sh +++ b/jvm-packages/test_distributed.sh @@ -1,5 +1,5 @@ #!/bin/bash # Simple script to test distributed version, to be deleted later. -cd xgboost4j-demo -java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4 +cd xgboost4j-flink +flink run -c ml.dmlc.xgboost4j.flink.Test -p 4 target/xgboost4j-flink-0.1-jar-with-dependencies.jar cd .. diff --git a/jvm-packages/xgboost4j-flink/pom.xml b/jvm-packages/xgboost4j-flink/pom.xml new file mode 100644 index 000000000..9c7b55427 --- /dev/null +++ b/jvm-packages/xgboost4j-flink/pom.xml @@ -0,0 +1,47 @@ + + + 4.0.0 + + org.dmlc + xgboostjvm + 0.1 + + xgboost4j-flink + 0.1 + jar + + + org.dmlc + xgboost4j + 0.1 + + + org.apache.commons + commons-lang3 + 3.4 + + + org.apache.flink + flink-java_2.11 + 0.10.2 + + + org.apache.flink + flink-scala_2.11 + 0.10.2 + + + org.apache.flink + flink-clients_2.11 + 0.10.2 + + + org.apache.flink + flink-ml_2.11 + 0.10.2 + + + + \ No newline at end of file diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala new file mode 100644 index 000000000..65de48277 --- /dev/null +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala @@ -0,0 +1,71 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.flink + +import ml.dmlc.xgboost4j.Rabit +import ml.dmlc.xgboost4j.RabitTracker +import ml.dmlc.xgboost4j.scala.Booster +import ml.dmlc.xgboost4j.scala.DMatrix +import ml.dmlc.xgboost4j.scala.XGBoost +import org.apache.commons.logging.Log +import org.apache.commons.logging.LogFactory +import org.apache.flink.api.common.functions.RichMapPartitionFunction +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.DataSet +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.ml.common.LabeledVector +import org.apache.flink.ml.MLUtils +import org.apache.flink.util.Collector + +class ScalaMapFunction(workerEnvs: java.util.Map[String, String]) + extends RichMapPartitionFunction[LabeledVector, Booster] { + val log = LogFactory.getLog(this.getClass) + def mapPartition(it : java.lang.Iterable[LabeledVector], collector: Collector[Booster]): Unit = { + workerEnvs.put("DMLC_TASK_ID", this.getRuntimeContext.getTaskName) + log.info("start with env" + workerEnvs.toString) + Rabit.init(workerEnvs) + + val trainMat = new DMatrix("/home/tqchen/github/xgboost/demo/data/agaricus.txt.train") + + val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1", + "objective" -> "binary:logistic").toMap + val watches = List("train" -> trainMat).toMap + val round = 2 + val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null) + Rabit.shutdown() + collector.collect(booster) + } +} + + + +object Test { + val log = LogFactory.getLog(this.getClass) + def main(args: Array[String]) { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val data = MLUtils.readLibSVM(env, "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train") + val tracker = new RabitTracker(data.getExecutionEnvironment.getParallelism) + log.info("start with parallelism" + data.getExecutionEnvironment.getParallelism) + assert(data.getExecutionEnvironment.getParallelism >= 1) + tracker.start() + + val res = data.mapPartition(new ScalaMapFunction(tracker.getWorkerEnvs)).reduce((x, y) => x) + val model = res.collect() + log.info(model) + } +} + diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java index 99adcb7e5..e6f1c8a24 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java @@ -21,7 +21,7 @@ public class RabitTracker { // environment variable to be pased. private Map envs = new HashMap(); // number of workers to be submitted. - private int num_workers; + private int numWorkers; private AtomicReference trackerProcess = new AtomicReference(); static { @@ -63,8 +63,12 @@ public class RabitTracker { } - public RabitTracker(int num_workers) { - this.num_workers = num_workers; + public RabitTracker(int numWorkers) + throws XGBoostError { + if (numWorkers < 1) { + throw new XGBoostError("numWorkers must be greater equal to one"); + } + this.numWorkers = numWorkers; } /** @@ -100,7 +104,7 @@ public class RabitTracker { private boolean startTrackerProcess() { try { trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py + - " --num-workers=" + String.valueOf(num_workers))); + " --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers))); loadEnvs(trackerProcess.get().getInputStream()); return true; } catch (IOException ioe) {