diff --git a/jvm-packages/test_distributed.sh b/jvm-packages/test_distributed.sh
index c9a5b21be..b17f6a3b3 100755
--- a/jvm-packages/test_distributed.sh
+++ b/jvm-packages/test_distributed.sh
@@ -1,5 +1,5 @@
#!/bin/bash
# Simple script to test distributed version, to be deleted later.
-cd xgboost4j-demo
-java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4
+cd xgboost4j-flink || exit 1  # abort rather than run flink (and the later `cd ..`) from the wrong directory
+flink run -c ml.dmlc.xgboost4j.flink.Test -p 4 target/xgboost4j-flink-0.1-jar-with-dependencies.jar
cd ..
diff --git a/jvm-packages/xgboost4j-flink/pom.xml b/jvm-packages/xgboost4j-flink/pom.xml
new file mode 100644
index 000000000..9c7b55427
--- /dev/null
+++ b/jvm-packages/xgboost4j-flink/pom.xml
@@ -0,0 +1,47 @@
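+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <parent>
+        <groupId>org.dmlc</groupId>
+        <artifactId>xgboostjvm</artifactId>
+        <version>0.1</version>
+    </parent>
+    <artifactId>xgboost4j-flink</artifactId>
+    <version>0.1</version>
+    <packaging>jar</packaging>
+    <dependencies>
+        <dependency>
+            <groupId>org.dmlc</groupId>
+            <artifactId>xgboost4j</artifactId>
+            <version>0.1</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-lang3</artifactId>
+            <version>3.4</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-java_2.11</artifactId>
+            <version>0.10.2</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-scala_2.11</artifactId>
+            <version>0.10.2</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-clients_2.11</artifactId>
+            <version>0.10.2</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-ml_2.11</artifactId>
+            <version>0.10.2</version>
+        </dependency>
+    </dependencies>
+
+</project>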
\ No newline at end of file
diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala
new file mode 100644
index 000000000..65de48277
--- /dev/null
+++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala
@@ -0,0 +1,71 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.flink
+
+import ml.dmlc.xgboost4j.Rabit
+import ml.dmlc.xgboost4j.RabitTracker
+import ml.dmlc.xgboost4j.scala.Booster
+import ml.dmlc.xgboost4j.scala.DMatrix
+import ml.dmlc.xgboost4j.scala.XGBoost
+import org.apache.commons.logging.Log
+import org.apache.commons.logging.LogFactory
+import org.apache.flink.api.common.functions.RichMapPartitionFunction
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.DataSet
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.ml.common.LabeledVector
+import org.apache.flink.ml.MLUtils
+import org.apache.flink.util.Collector
+
+/** Flink partition function: each parallel subtask joins Rabit and trains a demo booster. */
+class ScalaMapFunction(workerEnvs: java.util.Map[String, String])
+  extends RichMapPartitionFunction[LabeledVector, Booster] {
+  val log = LogFactory.getLog(this.getClass)
+  def mapPartition(it : java.lang.Iterable[LabeledVector], collector: Collector[Booster]): Unit = {
+    // The subtask index is unique per parallel instance; the task name is shared by all of them.
+    workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
+    log.info("start with env" + workerEnvs.toString)
+    Rabit.init(workerEnvs)
+    val booster = try {
+      // NOTE(review): `it` is ignored; every worker reads this hard-coded demo file.
+      val trainMat = new DMatrix("/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
+      val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+        "objective" -> "binary:logistic").toMap
+      XGBoost.train(paramMap, trainMat, 2, List("train" -> trainMat).toMap, null, null)
+    } finally Rabit.shutdown() // always leave Rabit, even if training throws
+    collector.collect(booster)
+  }
+}
+
+
+
+/** Demo driver: starts a Rabit tracker, trains in every Flink partition, logs the kept booster. */
+object Test {
+  val log = LogFactory.getLog(this.getClass)
+  def main(args: Array[String]): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val data = MLUtils.readLibSVM(env, "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
+    val parallelism = data.getExecutionEnvironment.getParallelism
+    val tracker = new RabitTracker(parallelism)
+    log.info("start with parallelism" + parallelism)
+    assert(parallelism >= 1)
+    tracker.start()
+    val kept = data.mapPartition(new ScalaMapFunction(tracker.getWorkerEnvs)).reduce((x, y) => x)
+    log.info(kept.collect())
+  }
+}
+
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java
index 99adcb7e5..e6f1c8a24 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java
@@ -21,7 +21,7 @@ public class RabitTracker {
// environment variable to be pased.
private Map envs = new HashMap();
// number of workers to be submitted.
- private int num_workers;
+ private int numWorkers;
private AtomicReference trackerProcess = new AtomicReference();
static {
@@ -63,8 +63,12 @@ public class RabitTracker {
}
- public RabitTracker(int num_workers) {
- this.num_workers = num_workers;
+  // Rejects a worker count below one up front; the message reports the offending value.
+  public RabitTracker(int numWorkers) throws XGBoostError {
+    if (numWorkers < 1) {
+      throw new XGBoostError("numWorkers must be at least 1, got " + numWorkers);
+    }
+    this.numWorkers = numWorkers;
}
/**
@@ -100,7 +104,7 @@ public class RabitTracker {
private boolean startTrackerProcess() {
try {
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
- " --num-workers=" + String.valueOf(num_workers)));
+ " --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers)));
loadEnvs(trackerProcess.get().getInputStream());
return true;
} catch (IOException ioe) {