[Flink] Check
This commit is contained in:
parent
2cec10c46f
commit
81dbf564a4
@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
# Simple script to test distributed version, to be deleted later.
|
||||
cd xgboost4j-demo
|
||||
java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4
|
||||
cd xgboost4j-flink
|
||||
flink run -c ml.dmlc.xgboost4j.flink.Test -p 4 target/xgboost4j-flink-0.1-jar-with-dependencies.jar
|
||||
cd ..
|
||||
|
||||
47
jvm-packages/xgboost4j-flink/pom.xml
Normal file
47
jvm-packages/xgboost4j-flink/pom.xml
Normal file
@ -0,0 +1,47 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<artifactId>xgboostjvm</artifactId>
|
||||
<version>0.1</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j-flink</artifactId>
|
||||
<version>0.1</version>
|
||||
<packaging>jar</packaging>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<artifactId>xgboost4j</artifactId>
|
||||
<version>0.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.4</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-java_2.11</artifactId>
|
||||
<version>0.10.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-scala_2.11</artifactId>
|
||||
<version>0.10.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-clients_2.11</artifactId>
|
||||
<version>0.10.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-ml_2.11</artifactId>
|
||||
<version>0.10.2</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
@ -0,0 +1,71 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.flink
|
||||
|
||||
import ml.dmlc.xgboost4j.Rabit
|
||||
import ml.dmlc.xgboost4j.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import ml.dmlc.xgboost4j.scala.XGBoost
|
||||
import org.apache.commons.logging.Log
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.flink.api.common.functions.RichMapPartitionFunction
|
||||
import org.apache.flink.api.scala._
|
||||
import org.apache.flink.api.scala.DataSet
|
||||
import org.apache.flink.api.scala.ExecutionEnvironment
|
||||
import org.apache.flink.ml.common.LabeledVector
|
||||
import org.apache.flink.ml.MLUtils
|
||||
import org.apache.flink.util.Collector
|
||||
|
||||
/**
 * Flink map-partition function that trains one XGBoost booster per parallel subtask.
 *
 * Each invocation initializes Rabit from the supplied worker environment, trains a booster
 * and emits it through the collector.
 *
 * NOTE(review): the incoming partition iterator is never consumed; the training matrix is
 * loaded from a hard-coded local file instead. Presumably temporary test scaffolding.
 *
 * @param workerEnvs environment entries handed to `Rabit.init`; this map is not modified
 */
class ScalaMapFunction(workerEnvs: java.util.Map[String, String])
  extends RichMapPartitionFunction[LabeledVector, Booster] {
  val log = LogFactory.getLog(this.getClass)

  /**
   * Trains a booster inside a Rabit session and collects it.
   *
   * @param it        records of this partition (currently unused, see class note)
   * @param collector receives the single trained booster
   */
  def mapPartition(
      it: java.lang.Iterable[LabeledVector],
      collector: Collector[Booster]): Unit = {
    // Work on a private copy so the constructor argument is never mutated.
    val taskEnvs = new java.util.HashMap[String, String](workerEnvs)
    // The task name is identical for all parallel subtasks; the subtask index is what
    // distinguishes one worker from another.
    taskEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
    log.info("start with env" + taskEnvs.toString)
    Rabit.init(taskEnvs)

    val booster =
      try {
        val trainMat = new DMatrix("/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")

        val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
          "objective" -> "binary:logistic").toMap
        val watches = List("train" -> trainMat).toMap
        val round = 2
        XGBoost.train(paramMap, trainMat, round, watches, null, null)
      } finally {
        // Always leave the Rabit session, even when loading or training fails.
        Rabit.shutdown()
      }
    collector.collect(booster)
  }
}
|
||||
|
||||
|
||||
|
||||
/**
 * Manual driver for the distributed Flink training path.
 *
 * Reads a LibSVM file, starts a Rabit tracker sized to the job parallelism, runs
 * [[ScalaMapFunction]] over every partition, keeps one of the resulting boosters and logs it.
 */
object Test {
  val log = LogFactory.getLog(this.getClass)

  /**
   * Entry point.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val data = MLUtils.readLibSVM(env, "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")

    // Parallelism of the environment that owns the data set; read once and reused below.
    val parallelism = data.getExecutionEnvironment.getParallelism
    val tracker = new RabitTracker(parallelism)
    log.info("start with parallelism" + parallelism)
    assert(parallelism >= 1)
    tracker.start()

    // Every partition yields one booster; the reduce keeps the first of each pair.
    val trained = data.mapPartition(new ScalaMapFunction(tracker.getWorkerEnvs))
    val res = trained.reduce((first, _) => first)
    val model = res.collect()
    log.info(model)
  }
}
|
||||
|
||||
@ -21,7 +21,7 @@ public class RabitTracker {
|
||||
// environment variable to be passed.
|
||||
private Map<String, String> envs = new HashMap<String, String>();
|
||||
// number of workers to be submitted.
|
||||
private int num_workers;
|
||||
private int numWorkers;
|
||||
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||
|
||||
static {
|
||||
@ -63,8 +63,12 @@ public class RabitTracker {
|
||||
}
|
||||
|
||||
|
||||
public RabitTracker(int num_workers) {
|
||||
this.num_workers = num_workers;
|
||||
public RabitTracker(int numWorkers)
|
||||
throws XGBoostError {
|
||||
if (numWorkers < 1) {
|
||||
throw new XGBoostError("numWorkers must be greater equal to one");
|
||||
}
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -100,7 +104,7 @@ public class RabitTracker {
|
||||
private boolean startTrackerProcess() {
|
||||
try {
|
||||
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
|
||||
" --num-workers=" + String.valueOf(num_workers)));
|
||||
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers)));
|
||||
loadEnvs(trackerProcess.get().getInputStream());
|
||||
return true;
|
||||
} catch (IOException ioe) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user