[FLINK] Make runnable flink

This commit is contained in:
tqchen 2016-03-05 17:58:58 -08:00
parent 3ddddfce79
commit 99dc311f6d
7 changed files with 124 additions and 34 deletions

@ -1 +1 @@
Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0
Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Simple script to test distributed version, to be deleted later.
cd xgboost4j-flink
flink run -c ml.dmlc.xgboost4j.flink.Test -p 4 target/xgboost4j-flink-0.1-jar-with-dependencies.jar
flink run -c ml.dmlc.xgboost4j.flink.Test -p 2 target/xgboost4j-flink-0.1-jar-with-dependencies.jar
cd ..

View File

@ -10,6 +10,17 @@
</parent>
<artifactId>xgboost4j-flink</artifactId>
<version>0.1</version>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<skipAssembly>false</skipAssembly>
</configuration>
</plugin>
</plugins>
</build>
<packaging>jar</packaging>
<dependencies>
<dependency>

View File

@ -16,10 +16,6 @@
package ml.dmlc.xgboost4j.flink
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker}
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.XGBoost
import org.apache.commons.logging.Log
import org.apache.commons.logging.LogFactory
import org.apache.flink.api.common.functions.RichMapPartitionFunction
@ -30,26 +26,6 @@ import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.MLUtils
import org.apache.flink.util.Collector
/**
 * Flink partition function that runs one Rabit worker per parallel subtask
 * and trains an XGBoost booster inside it.
 *
 * Note: the training data is NOT taken from the partition iterator; it is
 * loaded from a hard-coded local path, so `it` is ignored.
 *
 * @param workerEnvs environment entries passed to `Rabit.init`; this function adds
 *                   `DMLC_TASK_ID` (the subtask index) before initialising Rabit.
 */
class ScalaMapFunction(workerEnvs: java.util.Map[String, String])
  extends RichMapPartitionFunction[LabeledVector, Booster] {
  val log = LogFactory.getLog(this.getClass)

  /**
   * Initialises Rabit, trains for two rounds and emits the resulting booster.
   *
   * @param it        records of this partition (unused, see class comment)
   * @param collector receives exactly one booster on success
   */
  def mapPartition(it : java.lang.Iterable[LabeledVector], collector: Collector[Booster]): Unit = {
    workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
    log.info("start with env" + workerEnvs.toString)
    Rabit.init(workerEnvs)
    // Always shut Rabit down, even when loading or training fails, so the
    // worker does not stay initialised after an exception.
    val booster = try {
      val trainMat = new DMatrix("/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
      val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
        "objective" -> "binary:logistic").toMap
      val watches = List("train" -> trainMat).toMap
      val round = 2
      XGBoost.train(paramMap, trainMat, round, watches, null, null)
    } finally {
      Rabit.shutdown()
    }
    collector.collect(booster)
  }
}
object Test {
@ -57,13 +33,10 @@ object Test {
def main(args: Array[String]) {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val data = MLUtils.readLibSVM(env, "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
val tracker = new RabitTracker(data.getExecutionEnvironment.getParallelism)
log.info("start with parallelism" + data.getExecutionEnvironment.getParallelism)
assert(data.getExecutionEnvironment.getParallelism >= 1)
tracker.start()
val res = data.mapPartition(new ScalaMapFunction(tracker.getWorkerEnvs)).reduce((x, y) => x)
val model = res.collect()
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val round = 2
val model = XGBoost.train(paramMap, data, round)
log.info(model)
}
}

View File

@ -0,0 +1,83 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.flink
import scala.collection.JavaConverters.asScalaIteratorConverter;
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{RabitTracker, Rabit}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
import org.apache.commons.logging.LogFactory
import org.apache.flink.api.common.functions.RichMapPartitionFunction
import org.apache.flink.api.scala.DataSet
import org.apache.flink.api.scala._
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.util.Collector
object XGBoost {
  /**
   * Helper map-partition function that runs one Rabit worker per Flink subtask
   * and trains a booster on the records of that partition.
   *
   * @param paramMap   The parameters passed to XGBoost training.
   * @param round      Number of boosting rounds to train.
   * @param workerEnvs Environment entries passed to `Rabit.init`; `DMLC_TASK_ID`
   *                   (the subtask index) is added before initialisation.
   */
  private class MapFunction(paramMap: Map[String, AnyRef],
                            round: Int,
                            workerEnvs: java.util.Map[String, String])
    extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
    val logger = LogFactory.getLog(this.getClass)

    /** Converts a Flink labeled vector into an XGBoost4J labeled point. */
    private def toLabeledPoint(x: LabeledVector): LabeledPoint = {
      val (index, value) = x.vector.toSeq.unzip
      LabeledPoint.fromSparseVector(x.label.toFloat,
        index.toArray, value.map(z => z.toFloat).toArray)
    }

    def mapPartition(it: java.lang.Iterable[LabeledVector],
                     collector: Collector[XGBoostModel]): Unit = {
      workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
      logger.info("start with env" + workerEnvs.toString)
      Rabit.init(workerEnvs)
      // Always shut Rabit down, even if building the matrix or training throws.
      val booster = try {
        val dataIter = it.iterator().asScala.map(toLabeledPoint)
        val trainMat = new DMatrix(dataIter, null)
        val watches = List("train" -> trainMat).toMap
        // Use the caller-supplied `round`; it was previously shadowed by a
        // hard-coded local `val round = 2`.
        XGBoostScala.train(paramMap, trainMat, round, watches, null, null)
      } finally {
        Rabit.shutdown()
      }
      collector.collect(new XGBoostModel(booster))
    }
  }

  val logger = LogFactory.getLog(this.getClass)

  /**
   * Train a xgboost model with Flink.
   *
   * Starts a Rabit tracker sized to the parallelism of the data set's execution
   * environment, trains one booster per partition and returns the model kept
   * by the reduce step.
   *
   * @param params The parameters to XGBoost.
   * @param dtrain The training data.
   * @param round Number of rounds to train.
   * @return The trained model.
   * @throws Error if the Rabit tracker cannot be started.
   */
  def train(params: Map[String, AnyRef],
            dtrain: DataSet[LabeledVector],
            round: Int): XGBoostModel = {
    val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
    if (!tracker.start()) {
      throw new Error("Tracker cannot be started")
    }
    dtrain
      .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
      .reduce((x, y) => x).collect().head
  }
}

View File

@ -0,0 +1,23 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.flink
import ml.dmlc.xgboost4j.scala.Booster
/**
 * Serializable model type returned by Flink-based XGBoost training.
 *
 * It is constructed from a trained booster but currently declares no members,
 * so the booster is not exposed through this class.
 *
 * @param booster the trained booster this model is created from
 */
class XGBoostModel(booster: Booster) extends Serializable

2
rabit

@ -1 +1 @@
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0
Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043