From 99dc311f6dd109ae89c3b3efe5048b7b4376ef36 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 5 Mar 2016 17:58:58 -0800 Subject: [PATCH] [FLINK] Make runnable flink --- dmlc-core | 2 +- jvm-packages/test_distributed.sh | 2 +- jvm-packages/xgboost4j-flink/pom.xml | 11 +++ .../scala/ml/dmlc/xgboost4j/flink/Test.scala | 35 +------- .../ml/dmlc/xgboost4j/flink/XGBoost.scala | 83 +++++++++++++++++++ .../dmlc/xgboost4j/flink/XGBoostModel.scala | 23 +++++ rabit | 2 +- 7 files changed, 124 insertions(+), 34 deletions(-) create mode 100644 jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoost.scala create mode 100644 jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoostModel.scala diff --git a/dmlc-core b/dmlc-core index 71360023d..3f6ff43d3 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0 +Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f diff --git a/jvm-packages/test_distributed.sh b/jvm-packages/test_distributed.sh index b17f6a3b3..736034de5 100755 --- a/jvm-packages/test_distributed.sh +++ b/jvm-packages/test_distributed.sh @@ -1,5 +1,5 @@ #!/bin/bash # Simple script to test distributed version, to be deleted later. cd xgboost4j-flink -flink run -c ml.dmlc.xgboost4j.flink.Test -p 4 target/xgboost4j-flink-0.1-jar-with-dependencies.jar +flink run -c ml.dmlc.xgboost4j.flink.Test -p 2 target/xgboost4j-flink-0.1-jar-with-dependencies.jar cd .. diff --git a/jvm-packages/xgboost4j-flink/pom.xml b/jvm-packages/xgboost4j-flink/pom.xml index ae49c2295..7e6ed2b06 100644 --- a/jvm-packages/xgboost4j-flink/pom.xml +++ b/jvm-packages/xgboost4j-flink/pom.xml @@ -10,6 +10,17 @@ xgboost4j-flink 0.1 + + + + org.apache.maven.plugins + maven-assembly-plugin + + false + + + + jar diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala index 1beec66b3..e55e61702 100644 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala @@ -16,10 +16,6 @@ package ml.dmlc.xgboost4j.flink -import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker} -import ml.dmlc.xgboost4j.scala.Booster -import ml.dmlc.xgboost4j.scala.DMatrix -import ml.dmlc.xgboost4j.scala.XGBoost import org.apache.commons.logging.Log import org.apache.commons.logging.LogFactory import org.apache.flink.api.common.functions.RichMapPartitionFunction @@ -30,26 +26,6 @@ import org.apache.flink.ml.common.LabeledVector import org.apache.flink.ml.MLUtils import org.apache.flink.util.Collector -class ScalaMapFunction(workerEnvs: java.util.Map[String, String]) - extends RichMapPartitionFunction[LabeledVector, Booster] { - val log = LogFactory.getLog(this.getClass) - def mapPartition(it : java.lang.Iterable[LabeledVector], collector: Collector[Booster]): Unit = { - workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask)) - log.info("start with env" + workerEnvs.toString) - Rabit.init(workerEnvs) - - val trainMat = new DMatrix("/home/tqchen/github/xgboost/demo/data/agaricus.txt.train") - - val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "binary:logistic").toMap - val watches = List("train" -> trainMat).toMap - val round = 2 - val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null) - Rabit.shutdown() - collector.collect(booster) - } -} - object Test { @@ -57,13 +33,10 @@ object Test { def main(args: Array[String]) { val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment val data = MLUtils.readLibSVM(env, "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train") - val tracker = new RabitTracker(data.getExecutionEnvironment.getParallelism) - log.info("start with parallelism" + data.getExecutionEnvironment.getParallelism) - assert(data.getExecutionEnvironment.getParallelism >= 1) - tracker.start() - - val res = data.mapPartition(new ScalaMapFunction(tracker.getWorkerEnvs)).reduce((x, y) => x) - val model = res.collect() + val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1", + "objective" -> "binary:logistic").toMap + val round = 2 + val model = XGBoost.train(paramMap, data, round) log.info(model) } } diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoost.scala new file mode 100644 index 000000000..8f1e8260a --- /dev/null +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoost.scala @@ -0,0 +1,83 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.flink +import scala.collection.JavaConverters.asScalaIteratorConverter; +import ml.dmlc.xgboost4j.LabeledPoint +import ml.dmlc.xgboost4j.java.{RabitTracker, Rabit} +import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala} +import org.apache.commons.logging.LogFactory +import org.apache.flink.api.common.functions.RichMapPartitionFunction +import org.apache.flink.api.scala.DataSet +import org.apache.flink.api.scala._ +import org.apache.flink.ml.common.LabeledVector +import org.apache.flink.util.Collector + +object XGBoost { + /** + * Helper map function to start the job. + * + * @param workerEnvs + */ + private class MapFunction(paramMap: Map[String, AnyRef], + round: Int, + workerEnvs: java.util.Map[String, String]) + extends RichMapPartitionFunction[LabeledVector, XGBoostModel] { + val logger = LogFactory.getLog(this.getClass) + + def mapPartition(it: java.lang.Iterable[LabeledVector], + collector: Collector[XGBoostModel]): Unit = { + workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask)) + logger.info("start with env" + workerEnvs.toString) + Rabit.init(workerEnvs) + val mapper = (x: LabeledVector) => { + val (index, value) = x.vector.toSeq.unzip + LabeledPoint.fromSparseVector(x.label.toFloat, + index.toArray, value.map(z => z.toFloat).toArray) + } + val dataIter = for (x <- it.iterator().asScala) yield mapper(x) + val trainMat = new DMatrix(dataIter, null) + val watches = List("train" -> trainMat).toMap + val round = 2 + val booster = XGBoostScala.train(paramMap, trainMat, round, watches, null, null) + Rabit.shutdown() + collector.collect(new XGBoostModel(booster)) + } + } + + val logger = LogFactory.getLog(this.getClass) + + /** + * Train a xgboost model with link. + * + * @param params The parameters to XGBoost. + * @param dtrain The training data. + * @param round Number of rounds to train. + */ + def train(params: Map[String, AnyRef], + dtrain: DataSet[LabeledVector], + round: Int): XGBoostModel = { + val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism) + if (tracker.start()) { + dtrain + .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs)) + .reduce((x, y) => x).collect().head + } else { + throw new Error("Tracker cannot be started") + null + } + } +} diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoostModel.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoostModel.scala new file mode 100644 index 000000000..4197bd724 --- /dev/null +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoostModel.scala @@ -0,0 +1,23 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.flink + +import ml.dmlc.xgboost4j.scala.Booster + +class XGBoostModel (booster: Booster) extends Serializable { + +} diff --git a/rabit b/rabit index 1392e9f3d..be50e7b63 160000 --- a/rabit +++ b/rabit @@ -1 +1 @@ -Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0 +Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043