[Flink] Check

This commit is contained in:
tqchen 2016-03-05 09:38:43 -08:00 committed by CodingCat
parent 2cec10c46f
commit 81dbf564a4
4 changed files with 128 additions and 6 deletions

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Simple script to test distributed version, to be deleted later.
cd xgboost4j-demo
java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4
cd xgboost4j-flink
flink run -c ml.dmlc.xgboost4j.flink.Test -p 4 target/xgboost4j-flink-0.1-jar-with-dependencies.jar
cd ..

View File

@ -0,0 +1,47 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
</parent>
<artifactId>xgboost4j-flink</artifactId>
<version>0.1</version>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.1</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-java_2.11</artifactId>
<version>0.10.2</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-scala_2.11</artifactId>
<version>0.10.2</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-clients_2.11</artifactId>
<version>0.10.2</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-ml_2.11</artifactId>
<version>0.10.2</version>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,71 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.flink
import ml.dmlc.xgboost4j.Rabit
import ml.dmlc.xgboost4j.RabitTracker
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.XGBoost
import org.apache.commons.logging.Log
import org.apache.commons.logging.LogFactory
import org.apache.flink.api.common.functions.RichMapPartitionFunction
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.DataSet
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.MLUtils
import org.apache.flink.util.Collector
class ScalaMapFunction(workerEnvs: java.util.Map[String, String])
extends RichMapPartitionFunction[LabeledVector, Booster] {
val log = LogFactory.getLog(this.getClass)
def mapPartition(it : java.lang.Iterable[LabeledVector], collector: Collector[Booster]): Unit = {
workerEnvs.put("DMLC_TASK_ID", this.getRuntimeContext.getTaskName)
log.info("start with env" + workerEnvs.toString)
Rabit.init(workerEnvs)
val trainMat = new DMatrix("/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val watches = List("train" -> trainMat).toMap
val round = 2
val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null)
Rabit.shutdown()
collector.collect(booster)
}
}
object Test {
val log = LogFactory.getLog(this.getClass)
def main(args: Array[String]) {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val data = MLUtils.readLibSVM(env, "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
val tracker = new RabitTracker(data.getExecutionEnvironment.getParallelism)
log.info("start with parallelism" + data.getExecutionEnvironment.getParallelism)
assert(data.getExecutionEnvironment.getParallelism >= 1)
tracker.start()
val res = data.mapPartition(new ScalaMapFunction(tracker.getWorkerEnvs)).reduce((x, y) => x)
val model = res.collect()
log.info(model)
}
}

View File

@ -21,7 +21,7 @@ public class RabitTracker {
// environment variable to be pased.
private Map<String, String> envs = new HashMap<String, String>();
// number of workers to be submitted.
private int num_workers;
private int numWorkers;
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
static {
@ -63,8 +63,12 @@ public class RabitTracker {
}
public RabitTracker(int num_workers) {
this.num_workers = num_workers;
public RabitTracker(int numWorkers)
throws XGBoostError {
if (numWorkers < 1) {
throw new XGBoostError("numWorkers must be greater equal to one");
}
this.numWorkers = numWorkers;
}
/**
@ -100,7 +104,7 @@ public class RabitTracker {
private boolean startTrackerProcess() {
try {
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
" --num-workers=" + String.valueOf(num_workers)));
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers)));
loadEnvs(trackerProcess.get().getInputStream());
return true;
} catch (IOException ioe) {