[Flink] Check
This commit is contained in:
parent
2cec10c46f
commit
81dbf564a4
@ -1,5 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# Simple script to test distributed version, to be deleted later.
|
# Simple script to test distributed version, to be deleted later.
|
||||||
cd xgboost4j-demo
|
cd xgboost4j-flink
|
||||||
java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4
|
flink run -c ml.dmlc.xgboost4j.flink.Test -p 4 target/xgboost4j-flink-0.1-jar-with-dependencies.jar
|
||||||
cd ..
|
cd ..
|
||||||
|
|||||||
47
jvm-packages/xgboost4j-flink/pom.xml
Normal file
47
jvm-packages/xgboost4j-flink/pom.xml
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
<parent>
|
||||||
|
<groupId>org.dmlc</groupId>
|
||||||
|
<artifactId>xgboostjvm</artifactId>
|
||||||
|
<version>0.1</version>
|
||||||
|
</parent>
|
||||||
|
<artifactId>xgboost4j-flink</artifactId>
|
||||||
|
<version>0.1</version>
|
||||||
|
<packaging>jar</packaging>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.dmlc</groupId>
|
||||||
|
<artifactId>xgboost4j</artifactId>
|
||||||
|
<version>0.1</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-lang3</artifactId>
|
||||||
|
<version>3.4</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.flink</groupId>
|
||||||
|
<artifactId>flink-java_2.11</artifactId>
|
||||||
|
<version>0.10.2</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.flink</groupId>
|
||||||
|
<artifactId>flink-scala_2.11</artifactId>
|
||||||
|
<version>0.10.2</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.flink</groupId>
|
||||||
|
<artifactId>flink-clients_2.11</artifactId>
|
||||||
|
<version>0.10.2</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.flink</groupId>
|
||||||
|
<artifactId>flink-ml_2.11</artifactId>
|
||||||
|
<version>0.10.2</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
</project>
|
||||||
@ -0,0 +1,71 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.flink
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.Rabit
|
||||||
|
import ml.dmlc.xgboost4j.RabitTracker
|
||||||
|
import ml.dmlc.xgboost4j.scala.Booster
|
||||||
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
|
import ml.dmlc.xgboost4j.scala.XGBoost
|
||||||
|
import org.apache.commons.logging.Log
|
||||||
|
import org.apache.commons.logging.LogFactory
|
||||||
|
import org.apache.flink.api.common.functions.RichMapPartitionFunction
|
||||||
|
import org.apache.flink.api.scala._
|
||||||
|
import org.apache.flink.api.scala.DataSet
|
||||||
|
import org.apache.flink.api.scala.ExecutionEnvironment
|
||||||
|
import org.apache.flink.ml.common.LabeledVector
|
||||||
|
import org.apache.flink.ml.MLUtils
|
||||||
|
import org.apache.flink.util.Collector
|
||||||
|
|
||||||
|
class ScalaMapFunction(workerEnvs: java.util.Map[String, String])
|
||||||
|
extends RichMapPartitionFunction[LabeledVector, Booster] {
|
||||||
|
val log = LogFactory.getLog(this.getClass)
|
||||||
|
def mapPartition(it : java.lang.Iterable[LabeledVector], collector: Collector[Booster]): Unit = {
|
||||||
|
workerEnvs.put("DMLC_TASK_ID", this.getRuntimeContext.getTaskName)
|
||||||
|
log.info("start with env" + workerEnvs.toString)
|
||||||
|
Rabit.init(workerEnvs)
|
||||||
|
|
||||||
|
val trainMat = new DMatrix("/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
|
||||||
|
|
||||||
|
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic").toMap
|
||||||
|
val watches = List("train" -> trainMat).toMap
|
||||||
|
val round = 2
|
||||||
|
val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null)
|
||||||
|
Rabit.shutdown()
|
||||||
|
collector.collect(booster)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
object Test {
|
||||||
|
val log = LogFactory.getLog(this.getClass)
|
||||||
|
def main(args: Array[String]) {
|
||||||
|
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
|
||||||
|
val data = MLUtils.readLibSVM(env, "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
|
||||||
|
val tracker = new RabitTracker(data.getExecutionEnvironment.getParallelism)
|
||||||
|
log.info("start with parallelism" + data.getExecutionEnvironment.getParallelism)
|
||||||
|
assert(data.getExecutionEnvironment.getParallelism >= 1)
|
||||||
|
tracker.start()
|
||||||
|
|
||||||
|
val res = data.mapPartition(new ScalaMapFunction(tracker.getWorkerEnvs)).reduce((x, y) => x)
|
||||||
|
val model = res.collect()
|
||||||
|
log.info(model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ public class RabitTracker {
|
|||||||
// environment variable to be pased.
|
// environment variable to be pased.
|
||||||
private Map<String, String> envs = new HashMap<String, String>();
|
private Map<String, String> envs = new HashMap<String, String>();
|
||||||
// number of workers to be submitted.
|
// number of workers to be submitted.
|
||||||
private int num_workers;
|
private int numWorkers;
|
||||||
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||||
|
|
||||||
static {
|
static {
|
||||||
@ -63,8 +63,12 @@ public class RabitTracker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public RabitTracker(int num_workers) {
|
public RabitTracker(int numWorkers)
|
||||||
this.num_workers = num_workers;
|
throws XGBoostError {
|
||||||
|
if (numWorkers < 1) {
|
||||||
|
throw new XGBoostError("numWorkers must be greater equal to one");
|
||||||
|
}
|
||||||
|
this.numWorkers = numWorkers;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -100,7 +104,7 @@ public class RabitTracker {
|
|||||||
private boolean startTrackerProcess() {
|
private boolean startTrackerProcess() {
|
||||||
try {
|
try {
|
||||||
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
|
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
|
||||||
" --num-workers=" + String.valueOf(num_workers)));
|
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers)));
|
||||||
loadEnvs(trackerProcess.get().getInputStream());
|
loadEnvs(trackerProcess.get().getInputStream());
|
||||||
return true;
|
return true;
|
||||||
} catch (IOException ioe) {
|
} catch (IOException ioe) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user