Updated Flink 1.8 -> 1.17. Added smoke tests for Flink (#9046)

Boris 2023-04-26 12:41:11 +02:00 committed by GitHub
parent a320b402a5
commit 0e7377ba9c
12 changed files with 591 additions and 209 deletions
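The upgrade swaps the old Scala API (org.apache.flink.api.scala.DataSet, org.apache.flink.ml.MLUtils, LabeledVector) for a Java-first API built on flink-clients and flink-ml-servable-core: training consumes a DataSet<Tuple2<Vector, Double>> and returns an XGBoostModel. What follows is a minimal usage sketch, not part of this diff; the class name is mine and the tiny in-memory dataset stands in for the CSV files the examples read.

import java.util.Arrays;
import java.util.HashMap;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import ml.dmlc.xgboost4j.java.flink.XGBoost;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;

public class FlinkQuickStart {
  public static void main(String[] args) throws Exception {
    final ExecutionEnvironment env = ExecutionEnvironment.createLocalEnvironment(1);
    // Two hand-written rows stand in for a real dataset.
    final DataSet<Tuple2<Vector, Double>> trainData = env.fromCollection(
        Arrays.asList(
            new Tuple2<Vector, Double>(Vectors.dense(1.0, 0.0, 3.0), 1.0),
            new Tuple2<Vector, Double>(Vectors.dense(0.0, 2.0, 1.0), 0.0)),
        TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>(){}));
    final HashMap<String, Object> params = new HashMap<>();
    params.put("eta", 0.1);
    params.put("max_depth", 2);
    params.put("objective", "binary:logistic");
    // Train for two boosting rounds, then score the training vectors.
    final XGBoostModel model = XGBoost.train(trainData, params, 2);
    final DataSet<Float[]> predictions = model.predict(
        trainData.map(t -> t.f0).returns(TypeInformation.of(Vector.class)));
    System.out.println(predictions.collect().size());
  }
}

createLocalEnvironment(1) mirrors what the new smoke tests below do; a real job would use ExecutionEnvironment.getExecutionEnvironment as in the examples.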

View File

@@ -40,7 +40,7 @@ jobs:
key: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
restore-keys: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
-      - name: Test XGBoost4J
+      - name: Test XGBoost4J (Core)
run: |
cd jvm-packages
mvn test -B -pl :xgboost4j_2.12
@@ -67,7 +67,7 @@
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }}
-      - name: Test XGBoost4J-Spark
+      - name: Test XGBoost4J (Core, Spark, Examples)
run: |
rm -rfv build/
cd jvm-packages

View File

@@ -33,7 +33,7 @@
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
-    <flink.version>1.8.3</flink.version>
+    <flink.version>1.17.0</flink.version>
<spark.version>3.4.0</spark.version>
<scala.version>2.12.17</scala.version>
<scala.binary.version>2.12</scala.binary.version>

View File

@@ -26,7 +26,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
-      <version>2.0.0-SNAPSHOT</version>
+      <version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
@@ -37,12 +37,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
-      <version>2.0.0-SNAPSHOT</version>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.commons</groupId>
-      <artifactId>commons-lang3</artifactId>
-      <version>3.12.0</version>
+      <version>${project.version}</version>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,107 @@
/*
Copyright (c) 2014-2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example.flink;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple13;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import ml.dmlc.xgboost4j.java.flink.XGBoost;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
public class DistTrainWithFlinkExample {
static Tuple2<XGBoostModel, DataSet<Float[]>> runPrediction(
ExecutionEnvironment env,
java.nio.file.Path trainPath,
int percentage) throws Exception {
// reading data
final DataSet<Tuple2<Long, Tuple2<Vector, Double>>> data =
DataSetUtils.zipWithIndex(parseCsv(env, trainPath));
final long size = data.count();
final long trainCount = Math.round(size * 0.01 * percentage);
final DataSet<Tuple2<Vector, Double>> trainData =
data
.filter(item -> item.f0 < trainCount)
.map(t -> t.f1)
.returns(TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>(){}));
final DataSet<Vector> testData =
data
.filter(tuple -> tuple.f0 >= trainCount)
.map(t -> t.f1.f0)
.returns(TypeInformation.of(new TypeHint<Vector>(){}));
// define parameters
HashMap<String, Object> paramMap = new HashMap<String, Object>(3);
paramMap.put("eta", 0.1);
paramMap.put("max_depth", 2);
paramMap.put("objective", "binary:logistic");
// number of iterations
final int round = 2;
// train the model
XGBoostModel model = XGBoost.train(trainData, paramMap, round);
DataSet<Float[]> predTest = model.predict(testData);
return new Tuple2<XGBoostModel, DataSet<Float[]>>(model, predTest);
}
private static MapOperator<Tuple13<Double, String, Double, Double, Double, Integer, Integer,
Integer, Integer, Integer, Integer, Integer, Integer>,
Tuple2<Vector, Double>> parseCsv(ExecutionEnvironment env, Path trainPath) {
return env.readCsvFile(trainPath.toString())
.ignoreFirstLine()
.types(Double.class, String.class, Double.class, Double.class, Double.class,
Integer.class, Integer.class, Integer.class, Integer.class, Integer.class,
Integer.class, Integer.class, Integer.class)
.map(DistTrainWithFlinkExample::mapFunction);
}
private static Tuple2<Vector, Double> mapFunction(Tuple13<Double, String, Double, Double, Double,
Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer> tuple) {
final DenseVector dense = Vectors.dense(tuple.f2, tuple.f3, tuple.f4, tuple.f5, tuple.f6,
tuple.f7, tuple.f8, tuple.f9, tuple.f10, tuple.f11, tuple.f12);
if (tuple.f1.contains("inf")) {
return new Tuple2<Vector, Double>(dense, 1.0);
} else {
return new Tuple2<Vector, Double>(dense, 0.0);
}
}
public static void main(String[] args) throws Exception {
final java.nio.file.Path parentPath = java.nio.file.Paths.get(Arrays.stream(args)
.findFirst().orElse("."));
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
Tuple2<XGBoostModel, DataSet<Float[]>> tuple2 = runPrediction(
env, parentPath.resolve("veterans_lung_cancer.csv"), 70
);
List<Float[]> list = tuple2.f1.collect();
System.out.println(list.size());
}
}

View File

@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014 - 2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,27 +15,84 @@
*/
package ml.dmlc.xgboost4j.scala.example.flink
-import ml.dmlc.xgboost4j.scala.flink.XGBoost
-import org.apache.flink.api.scala.{ExecutionEnvironment, _}
-import org.apache.flink.ml.MLUtils
+import java.lang.{Double => JDouble, Long => JLong}
+import java.nio.file.{Path, Paths}
+import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
+import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
+import org.apache.flink.ml.linalg.{Vector, Vectors}
+import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
+import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
+import org.apache.flink.api.java.utils.DataSetUtils
object DistTrainWithFlink {
-  def main(args: Array[String]) {
-    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
-    // read trainining data
-    val trainData =
-      MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
-    val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test")
-    // define parameters
-    val paramMap = List(
-      "eta" -> 0.1,
-      "max_depth" -> 2,
-      "objective" -> "binary:logistic").toMap
+  import scala.jdk.CollectionConverters._
+
+  private val rowTypeHint = TypeInformation.of(new TypeHint[Tuple2[Vector, JDouble]]{})
+  private val testDataTypeHint = TypeInformation.of(classOf[Vector])
+
+  private[flink] def parseCsv(trainPath: Path)(implicit env: ExecutionEnvironment):
+      DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = {
+    DataSetUtils.zipWithIndex(
+      env
+        .readCsvFile(trainPath.toString)
+        .ignoreFirstLine
+        .types(
+          classOf[Double], classOf[String], classOf[Double], classOf[Double], classOf[Double],
+          classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer],
+          classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer]
+        )
+        .map((row: Tuple13[Double, String, Double, Double, Double,
+            Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer]) => {
+          val dense = Vectors.dense(row.f2, row.f3, row.f4,
+            row.f5.toDouble, row.f6.toDouble, row.f7.toDouble, row.f8.toDouble,
+            row.f9.toDouble, row.f10.toDouble, row.f11.toDouble, row.f12.toDouble)
+          val label = if (row.f1.contains("inf")) {
+            JDouble.valueOf(1.0)
+          } else {
+            JDouble.valueOf(0.0)
+          }
+          new Tuple2[Vector, JDouble](dense, label)
+        })
+        .returns(rowTypeHint)
+    )
+  }
+
+  private[flink] def runPrediction(trainPath: Path, percentage: Int)
+                                  (implicit env: ExecutionEnvironment):
+      (XGBoostModel, DataSet[Array[Float]]) = {
+    // read training data
+    val data: DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = parseCsv(trainPath)
+    val trainSize = Math.round(0.01 * percentage * data.count())
+    val trainData: DataSet[Tuple2[Vector, JDouble]] =
+      data.filter(d => d.f0 < trainSize).map(_.f1).returns(rowTypeHint)
+    val testData: DataSet[Vector] =
+      data
+        .filter(d => d.f0 >= trainSize)
+        .map(_.f1.f0)
+        .returns(testDataTypeHint)
+    val paramMap = mapAsJavaMap(Map(
+      ("eta", "0.1".asInstanceOf[AnyRef]),
+      ("max_depth", "2"),
+      ("objective", "binary:logistic"),
+      ("verbosity", "1")
+    ))
    // number of iterations
    val round = 2
    // train the model
    val model = XGBoost.train(trainData, paramMap, round)
-    val predTest = model.predict(testData.map{x => x.vector})
-    model.saveModelAsHadoopFile("file:///path/to/xgboost.model")
+    val result = model.predict(testData).map(prediction => prediction.map(Float.unbox))
+    (model, result)
+  }
+
+  def main(args: Array[String]): Unit = {
+    implicit val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val parentPath = Paths.get(args.headOption.getOrElse("."))
+    val (_, predTest) = runPrediction(parentPath.resolve("veterans_lung_cancer.csv"), 70)
+    val list = predTest.collect().asScala
+    println(list.length)
+  }
}

View File

@@ -0,0 +1,36 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example.flink
import org.apache.flink.api.java.ExecutionEnvironment
import org.scalatest.Inspectors._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._
import java.nio.file.Paths
class DistTrainWithFlinkExampleTest extends AnyFunSuite {
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
private val data = parentPath.resolve("veterans_lung_cancer.csv")
test("Smoke test for java flink example") {
val env = ExecutionEnvironment.createLocalEnvironment(1)
val tuple2 = DistTrainWithFlinkExample.runPrediction(env, data, 70)
val results = tuple2.f1.collect()
results should have size 41
forEvery(results)(item => item should have size 1)
}
}

View File

@@ -0,0 +1,37 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example.flink
import org.apache.flink.api.java.ExecutionEnvironment
import org.scalatest.Inspectors._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._
import java.nio.file.Paths
import scala.jdk.CollectionConverters._
class DistTrainWithFlinkSuite extends AnyFunSuite {
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
private val data = parentPath.resolve("veterans_lung_cancer.csv")
test("Smoke test for scala flink example") {
implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1)
val (_, result) = DistTrainWithFlink.runPrediction(data, 70)
val results = result.collect().asScala
results should have size 41
forEvery(results)(item => item should have size 1)
}
}

View File

@@ -8,8 +8,11 @@
    <artifactId>xgboost-jvm_2.12</artifactId>
    <version>2.0.0-SNAPSHOT</version>
  </parent>
-  <artifactId>xgboost4j-flink_2.12</artifactId>
+  <artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
  <version>2.0.0-SNAPSHOT</version>
+  <properties>
+    <flink-ml.version>2.2.0</flink-ml.version>
+  </properties>
  <build>
    <plugins>
      <plugin>
@@ -26,32 +29,22 @@
    <dependency>
      <groupId>ml.dmlc</groupId>
      <artifactId>xgboost4j_${scala.binary.version}</artifactId>
-      <version>2.0.0-SNAPSHOT</version>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.commons</groupId>
-      <artifactId>commons-lang3</artifactId>
-      <version>3.12.0</version>
+      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.flink</groupId>
-      <artifactId>flink-scala_${scala.binary.version}</artifactId>
+      <artifactId>flink-clients</artifactId>
      <version>${flink.version}</version>
    </dependency>
-    <dependency>
-      <groupId>org.apache.flink</groupId>
-      <artifactId>flink-clients_${scala.binary.version}</artifactId>
-      <version>${flink.version}</version>
-    </dependency>
    <dependency>
      <groupId>org.apache.flink</groupId>
-      <artifactId>flink-ml_${scala.binary.version}</artifactId>
-      <version>${flink.version}</version>
+      <artifactId>flink-ml-servable-core</artifactId>
+      <version>${flink-ml.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.hadoop</groupId>
      <artifactId>hadoop-common</artifactId>
-      <version>3.3.5</version>
+      <version>${hadoop.version}</version>
    </dependency>
  </dependencies>

View File

@@ -0,0 +1,187 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.flink;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.RabitTracker;
import ml.dmlc.xgboost4j.java.XGBoostError;
public class XGBoost {
private static final Logger logger = LoggerFactory.getLogger(XGBoost.class);
private static class MapFunction
extends RichMapPartitionFunction<Tuple2<Vector, Double>, XGBoostModel> {
private final Map<String, Object> params;
private final int round;
private final Map<String, String> workerEnvs;
public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
this.params = params;
this.round = round;
this.workerEnvs = workerEnvs;
}
public void mapPartition(java.lang.Iterable<Tuple2<Vector, Double>> it,
Collector<XGBoostModel> collector) throws XGBoostError {
workerEnvs.put(
"DMLC_TASK_ID",
String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask())
);
if (logger.isInfoEnabled()) {
logger.info("start with env: {}", workerEnvs.entrySet().stream()
.map(e -> String.format("\"%s\": \"%s\"", e.getKey(), e.getValue()))
.collect(Collectors.joining(", "))
);
}
final Iterator<LabeledPoint> dataIter =
StreamSupport
.stream(it.spliterator(), false)
.map(VectorToPointMapper.INSTANCE)
.iterator();
if (dataIter.hasNext()) {
final DMatrix trainMat = new DMatrix(dataIter, null);
int numEarlyStoppingRounds =
Optional.ofNullable(params.get("numEarlyStoppingRounds"))
.map(x -> Integer.parseInt(x.toString()))
.orElse(0);
final Booster booster = trainBooster(trainMat, numEarlyStoppingRounds);
collector.collect(new XGBoostModel(booster));
} else {
logger.warn("Nothing to train with.");
}
}
private Booster trainBooster(DMatrix trainMat,
int numEarlyStoppingRounds) throws XGBoostError {
Booster booster;
final Map<String, DMatrix> watches =
new HashMap<String, DMatrix>() {{ put("train", trainMat); }};
try {
Communicator.init(workerEnvs);
booster = ml.dmlc.xgboost4j.java.XGBoost
.train(
trainMat,
params,
round,
watches,
null,
null,
null,
numEarlyStoppingRounds);
} catch (XGBoostError xgbException) {
final String identifier = String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask());
logger.warn(
String.format("XGBooster worker %s has failed due to", identifier),
xgbException
);
throw xgbException;
} finally {
Communicator.shutdown();
}
return booster;
}
private static class VectorToPointMapper
implements Function<Tuple2<Vector, Double>, LabeledPoint> {
public static VectorToPointMapper INSTANCE = new VectorToPointMapper();
@Override
public LabeledPoint apply(Tuple2<Vector, Double> tuple) {
final SparseVector vector = tuple.f0.toSparse();
final double[] values = vector.values;
final int size = values.length;
final float[] array = new float[size];
for (int i = 0; i < size; i++) {
array[i] = (float) values[i];
}
return new LabeledPoint(
tuple.f1.floatValue(),
vector.size(),
vector.indices,
array);
}
}
}
/**
* Load XGBoost model from path, using Hadoop Filesystem API.
*
* @param modelPath The path that is accessible by hadoop filesystem API.
* @return The loaded model
*/
public static XGBoostModel loadModelFromHadoopFile(final String modelPath) throws Exception {
final FileSystem fileSystem = FileSystem.get(new Configuration());
final Path f = new Path(modelPath);
try (FSDataInputStream opened = fileSystem.open(f)) {
return new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(opened));
}
}
/**
* Train an XGBoost model with Flink.
*
* @param dtrain The training data.
* @param params XGBoost parameters.
* @param numBoostRound Number of rounds to train.
*/
public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain,
Map<String, Object> params,
int numBoostRound) throws Exception {
final RabitTracker tracker =
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
if (tracker.start(0L)) {
return dtrain
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
.reduce((x, y) -> x)
.collect()
.get(0);
} else {
throw new Error("Tracker cannot be started");
}
}
}
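A hedged continuation of the quick-start sketch above, showing a save/load round trip through loadModelFromHadoopFile and the XGBoostModel save helper defined below; the file:// path and the "ubj" format choice are my assumptions, not part of the commit.

// Continuing the earlier sketch: persist the trained booster, then restore it.
// "file:///tmp/xgboost.ubj" is a placeholder path.
model.saveModelAsHadoopFile("file:///tmp/xgboost.ubj", "ubj");
XGBoostModel restored = XGBoost.loadModelFromHadoopFile("file:///tmp/xgboost.ubj");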

View File

@@ -0,0 +1,136 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.flink;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.stream.StreamSupport;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
public class XGBoostModel implements Serializable {
private static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(XGBoostModel.class);
private final Booster booster;
private final PredictorFunction predictorFunction;
public XGBoostModel(Booster booster) {
this.booster = booster;
this.predictorFunction = new PredictorFunction(booster);
}
/**
* Save the model as a Hadoop filesystem file.
*
* @param modelPath The model path as in Hadoop path.
*/
public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError {
booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)));
}
public byte[] toByteArray(String format) throws XGBoostError {
return booster.toByteArray(format);
}
/**
* Save the model as a Hadoop filesystem file.
*
* @param modelPath The model path as in Hadoop path.
* @param format The model format (ubj, json, deprecated)
* @throws XGBoostError internal error
* @throws IOException save error
*/
public void saveModelAsHadoopFile(String modelPath, String format)
throws IOException, XGBoostError {
booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)), format);
}
/**
* predict with the given DMatrix
*
* @param testSet the local test set represented as DMatrix
* @return prediction result
*/
public float[][] predict(DMatrix testSet) throws XGBoostError {
return booster.predict(testSet, true, 0);
}
/**
* Predict given vector dataset.
*
* @param data The dataset to be predicted.
* @return The prediction result.
*/
public DataSet<Float[]> predict(DataSet<Vector> data) {
return data.mapPartition(predictorFunction);
}
private static class PredictorFunction implements MapPartitionFunction<Vector, Float[]> {
private final Booster booster;
public PredictorFunction(Booster booster) {
this.booster = booster;
}
@Override
public void mapPartition(Iterable<Vector> it, Collector<Float[]> out) throws Exception {
final Iterator<LabeledPoint> dataIter =
StreamSupport.stream(it.spliterator(), false)
.map(Vector::toSparse)
.map(PredictorFunction::fromVector)
.iterator();
if (dataIter.hasNext()) {
final DMatrix data = new DMatrix(dataIter, null);
float[][] predictions = booster.predict(data, true, 2);
Arrays.stream(predictions).map(ArrayUtils::toObject).forEach(out::collect);
} else {
logger.debug("Empty partition");
}
}
private static LabeledPoint fromVector(SparseVector vector) {
final int[] index = vector.indices;
final double[] value = vector.values;
int size = value.length;
final float[] values = new float[size];
for (int i = 0; i < size; i++) {
values[i] = (float) value[i];
}
return new LabeledPoint(0.0f, vector.size(), index, values);
}
}
}
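The two predict overloads run in different places: predict(DMatrix) scores locally in the client process (with output margin enabled, per the call above), while predict(DataSet<Vector>) ships the booster into the Flink job via PredictorFunction. A local-scoring sketch reusing the restored model from the previous snippet; the literal values are placeholders, and ml.dmlc.xgboost4j.java.DMatrix is assumed imported.

// Score two rows locally, bypassing Flink: a dense 2x3 DMatrix built from
// row-major values (2 rows, 3 columns).
float[] rows = {1.0f, 0.0f, 3.0f, 0.0f, 2.0f, 1.0f};
DMatrix testMat = new DMatrix(rows, 2, 3);
float[][] margins = restored.predict(testMat); // margins, since this overload sets outputMargin = true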

View File

@@ -1,99 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.flink
import scala.collection.JavaConverters.asScalaIteratorConverter
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
import org.apache.commons.logging.LogFactory
import org.apache.flink.api.common.functions.RichMapPartitionFunction
import org.apache.flink.api.scala.{DataSet, _}
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.util.Collector
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
object XGBoost {
/**
* Helper map function to start the job.
*
* @param workerEnvs
*/
private class MapFunction(paramMap: Map[String, Any],
round: Int,
workerEnvs: java.util.Map[String, String])
extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
val logger = LogFactory.getLog(this.getClass)
def mapPartition(it: java.lang.Iterable[LabeledVector],
collector: Collector[XGBoostModel]): Unit = {
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
logger.info("start with env" + workerEnvs.toString)
Communicator.init(workerEnvs)
val mapper = (x: LabeledVector) => {
val (index, value) = x.vector.toSeq.unzip
LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
}
val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
val trainMat = new DMatrix(dataIter, null)
val watches = List("train" -> trainMat).toMap
val round = 2
val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds")
.map(_.toString.toInt).getOrElse(0)
val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
earlyStoppingRound = numEarlyStoppingRounds)
Communicator.shutdown()
collector.collect(new XGBoostModel(booster))
}
}
val logger = LogFactory.getLog(this.getClass)
/**
* Load XGBoost model from path, using Hadoop Filesystem API.
*
* @param modelPath The path that is accessible by hadoop filesystem API.
* @return The loaded model
*/
def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
new XGBoostModel(
XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath))))
}
/**
* Train a xgboost model with link.
*
* @param dtrain The training data.
* @param params The parameters to XGBoost.
* @param round Number of rounds to train.
*/
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
XGBoostModel = {
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
if (tracker.start(0L)) {
dtrain
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
.reduce((x, y) => x).collect().head
} else {
throw new Error("Tracker cannot be started")
null
}
}
}

View File

@@ -1,67 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.flink
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.flink.api.scala.{DataSet, _}
import org.apache.flink.ml.math.Vector
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
class XGBoostModel (booster: Booster) extends Serializable {
/**
* Save the model as a Hadoop filesystem file.
*
* @param modelPath The model path as in Hadoop path.
*/
def saveModelAsHadoopFile(modelPath: String): Unit = {
booster.saveModel(FileSystem
.get(new Configuration)
.create(new Path(modelPath)))
}
/**
* predict with the given DMatrix
* @param testSet the local test set represented as DMatrix
* @return prediction result
*/
def predict(testSet: DMatrix): Array[Array[Float]] = {
booster.predict(testSet, true, 0)
}
/**
* Predict given vector dataset.
*
* @param data The dataset to be predicted.
* @return The prediction result.
*/
def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
(it: Iterator[Vector]) => {
val mapper = (x: Vector) => {
val (index, value) = x.toSeq.unzip
LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
}
val dataIter = for (x <- it) yield mapper(x)
val dmat = new DMatrix(dataIter, null)
this.booster.predict(dmat)
}
data.mapPartition(predictMap)
}
}