From 0e7377ba9c559597056d4a525a04ffd643b3afd7 Mon Sep 17 00:00:00 2001 From: Boris Date: Wed, 26 Apr 2023 12:41:11 +0200 Subject: [PATCH] Updated flink 1.8 -> 1.17. Added smoke tests for Flink (#9046) --- .github/workflows/jvm_tests.yml | 4 +- jvm-packages/pom.xml | 2 +- jvm-packages/xgboost4j-example/pom.xml | 9 +- .../flink/DistTrainWithFlinkExample.java | 107 ++++++++++ .../example/flink/DistTrainWithFlink.scala | 91 +++++++-- .../flink/DistTrainWithFlinkExampleTest.scala | 36 ++++ .../flink/DistTrainWithFlinkSuite.scala | 37 ++++ jvm-packages/xgboost4j-flink/pom.xml | 25 +-- .../ml/dmlc/xgboost4j/java/flink/XGBoost.java | 187 ++++++++++++++++++ .../xgboost4j/java/flink/XGBoostModel.java | 136 +++++++++++++ .../dmlc/xgboost4j/scala/flink/XGBoost.scala | 99 ---------- .../xgboost4j/scala/flink/XGBoostModel.scala | 67 ------- 12 files changed, 591 insertions(+), 209 deletions(-) create mode 100644 jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExample.java create mode 100644 jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExampleTest.scala create mode 100644 jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlinkSuite.scala create mode 100644 jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java create mode 100644 jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoostModel.java delete mode 100644 jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala delete mode 100644 jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala diff --git a/.github/workflows/jvm_tests.yml b/.github/workflows/jvm_tests.yml index 8efcdc2ec..a2d8bb69a 100644 --- a/.github/workflows/jvm_tests.yml +++ b/.github/workflows/jvm_tests.yml @@ -40,7 +40,7 @@ jobs: key: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }} restore-keys: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }} - - name: Test XGBoost4J + - name: Test XGBoost4J (Core) run: | cd jvm-packages mvn test -B -pl :xgboost4j_2.12 @@ -67,7 +67,7 @@ jobs: AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }} - - name: Test XGBoost4J-Spark + - name: Test XGBoost4J (Core, Spark, Examples) run: | rm -rfv build/ cd jvm-packages diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 2aac8b00c..0ee7f0b1a 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -33,7 +33,7 @@ UTF-8 1.8 1.8 - 1.8.3 + 1.17.0 3.4.0 2.12.17 2.12 diff --git a/jvm-packages/xgboost4j-example/pom.xml b/jvm-packages/xgboost4j-example/pom.xml index d08e4f409..40c9c72a4 100644 --- a/jvm-packages/xgboost4j-example/pom.xml +++ b/jvm-packages/xgboost4j-example/pom.xml @@ -26,7 +26,7 @@ ml.dmlc xgboost4j-spark_${scala.binary.version} - 2.0.0-SNAPSHOT + ${project.version} org.apache.spark @@ -37,12 +37,7 @@ ml.dmlc xgboost4j-flink_${scala.binary.version} - 2.0.0-SNAPSHOT - - - org.apache.commons - commons-lang3 - 3.12.0 + ${project.version} diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExample.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExample.java new file mode 100644 index 000000000..94e5cdab5 --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExample.java @@ -0,0 +1,107 @@ +/* + Copyright (c) 2014-2021 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j.java.example.flink; + +import java.nio.file.Path; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; + +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.operators.MapOperator; +import org.apache.flink.api.java.tuple.Tuple13; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.utils.DataSetUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; + +import ml.dmlc.xgboost4j.java.flink.XGBoost; +import ml.dmlc.xgboost4j.java.flink.XGBoostModel; + + +public class DistTrainWithFlinkExample { + + static Tuple2> runPrediction( + ExecutionEnvironment env, + java.nio.file.Path trainPath, + int percentage) throws Exception { + // reading data + final DataSet>> data = + DataSetUtils.zipWithIndex(parseCsv(env, trainPath)); + final long size = data.count(); + final long trainCount = Math.round(size * 0.01 * percentage); + final DataSet> trainData = + data + .filter(item -> item.f0 < trainCount) + .map(t -> t.f1) + .returns(TypeInformation.of(new TypeHint>(){})); + final DataSet testData = + data + .filter(tuple -> tuple.f0 >= trainCount) + .map(t -> t.f1.f0) + .returns(TypeInformation.of(new TypeHint(){})); + + // define parameters + HashMap paramMap = new HashMap(3); + paramMap.put("eta", 0.1); + paramMap.put("max_depth", 2); + paramMap.put("objective", "binary:logistic"); + + // number of iterations + final int round = 2; + // train the model + XGBoostModel model = XGBoost.train(trainData, paramMap, round); + DataSet predTest = model.predict(testData); + return new Tuple2>(model, predTest); + } + + private static MapOperator, + Tuple2> parseCsv(ExecutionEnvironment env, Path trainPath) { + return env.readCsvFile(trainPath.toString()) + .ignoreFirstLine() + .types(Double.class, String.class, Double.class, Double.class, Double.class, + Integer.class, Integer.class, Integer.class, Integer.class, Integer.class, + Integer.class, Integer.class, Integer.class) + .map(DistTrainWithFlinkExample::mapFunction); + } + + private static Tuple2 mapFunction(Tuple13 tuple) { + final DenseVector dense = Vectors.dense(tuple.f2, tuple.f3, tuple.f4, tuple.f5, tuple.f6, + tuple.f7, tuple.f8, tuple.f9, tuple.f10, tuple.f11, tuple.f12); + if (tuple.f1.contains("inf")) { + return new Tuple2(dense, 1.0); + } else { + return new Tuple2(dense, 0.0); + } + } + + public static void main(String[] args) throws Exception { + final java.nio.file.Path parentPath = java.nio.file.Paths.get(Arrays.stream(args) + .findFirst().orElse(".")); + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + Tuple2> tuple2 = runPrediction( + env, parentPath.resolve("veterans_lung_cancer.csv"), 70 + ); + List list = tuple2.f1.collect(); + System.out.println(list.size()); + } +} diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala index 74b24ac35..cb859f62d 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 - 2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,27 +15,84 @@ */ package ml.dmlc.xgboost4j.scala.example.flink -import ml.dmlc.xgboost4j.scala.flink.XGBoost -import org.apache.flink.api.scala.{ExecutionEnvironment, _} -import org.apache.flink.ml.MLUtils +import java.lang.{Double => JDouble, Long => JLong} +import java.nio.file.{Path, Paths} +import org.apache.flink.api.java.tuple.{Tuple13, Tuple2} +import org.apache.flink.api.java.{DataSet, ExecutionEnvironment} +import org.apache.flink.ml.linalg.{Vector, Vectors} +import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel} +import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation} +import org.apache.flink.api.java.utils.DataSetUtils + object DistTrainWithFlink { - def main(args: Array[String]) { - val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment - // read trainining data - val trainData = - MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train") - val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test") - // define parameters - val paramMap = List( - "eta" -> 0.1, - "max_depth" -> 2, - "objective" -> "binary:logistic").toMap + import scala.jdk.CollectionConverters._ + private val rowTypeHint = TypeInformation.of(new TypeHint[Tuple2[Vector, JDouble]]{}) + private val testDataTypeHint = TypeInformation.of(classOf[Vector]) + + private[flink] def parseCsv(trainPath: Path)(implicit env: ExecutionEnvironment): + DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = { + DataSetUtils.zipWithIndex( + env + .readCsvFile(trainPath.toString) + .ignoreFirstLine + .types( + classOf[Double], classOf[String], classOf[Double], classOf[Double], classOf[Double], + classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer], + classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer] + ) + .map((row: Tuple13[Double, String, Double, Double, Double, + Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer]) => { + val dense = Vectors.dense(row.f2, row.f3, row.f4, + row.f5.toDouble, row.f6.toDouble, row.f7.toDouble, row.f8.toDouble, + row.f9.toDouble, row.f10.toDouble, row.f11.toDouble, row.f12.toDouble) + val label = if (row.f1.contains("inf")) { + JDouble.valueOf(1.0) + } else { + JDouble.valueOf(0.0) + } + new Tuple2[Vector, JDouble](dense, label) + }) + .returns(rowTypeHint) + ) + } + + private[flink] def runPrediction(trainPath: Path, percentage: Int) + (implicit env: ExecutionEnvironment): + (XGBoostModel, DataSet[Array[Float]]) = { + // read training data + val data: DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = parseCsv(trainPath) + val trainSize = Math.round(0.01 * percentage * data.count()) + val trainData: DataSet[Tuple2[Vector, JDouble]] = + data.filter(d => d.f0 < trainSize).map(_.f1).returns(rowTypeHint) + + + val testData: DataSet[Vector] = + data + .filter(d => d.f0 >= trainSize) + .map(_.f1.f0) + .returns(testDataTypeHint) + + val paramMap = mapAsJavaMap(Map( + ("eta", "0.1".asInstanceOf[AnyRef]), + ("max_depth", "2"), + ("objective", "binary:logistic"), + ("verbosity", "1") + )) + // number of iterations val round = 2 // train the model val model = XGBoost.train(trainData, paramMap, round) - val predTest = model.predict(testData.map{x => x.vector}) - model.saveModelAsHadoopFile("file:///path/to/xgboost.model") + val result = model.predict(testData).map(prediction => prediction.map(Float.unbox)) + (model, result) + } + + def main(args: Array[String]): Unit = { + implicit val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val parentPath = Paths.get(args.headOption.getOrElse(".")) + val (_, predTest) = runPrediction(parentPath.resolve("veterans_lung_cancer.csv"), 70) + val list = predTest.collect().asScala + println(list.length) } } diff --git a/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExampleTest.scala b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExampleTest.scala new file mode 100644 index 000000000..b9929639f --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExampleTest.scala @@ -0,0 +1,36 @@ +/* + Copyright (c) 2014-2023 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j.java.example.flink + +import org.apache.flink.api.java.ExecutionEnvironment +import org.scalatest.Inspectors._ +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers._ + +import java.nio.file.Paths + +class DistTrainWithFlinkExampleTest extends AnyFunSuite { + private val parentPath = Paths.get("../../").resolve("demo").resolve("data") + private val data = parentPath.resolve("veterans_lung_cancer.csv") + + test("Smoke test for scala flink example") { + val env = ExecutionEnvironment.createLocalEnvironment(1) + val tuple2 = DistTrainWithFlinkExample.runPrediction(env, data, 70) + val results = tuple2.f1.collect() + results should have size 41 + forEvery(results)(item => item should have size 1) + } +} diff --git a/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlinkSuite.scala b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlinkSuite.scala new file mode 100644 index 000000000..d9e98d81c --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlinkSuite.scala @@ -0,0 +1,37 @@ +/* + Copyright (c) 2014-2023 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j.scala.example.flink + +import org.apache.flink.api.java.ExecutionEnvironment +import org.scalatest.Inspectors._ +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers._ + +import java.nio.file.Paths +import scala.jdk.CollectionConverters._ + +class DistTrainWithFlinkSuite extends AnyFunSuite { + private val parentPath = Paths.get("../../").resolve("demo").resolve("data") + private val data = parentPath.resolve("veterans_lung_cancer.csv") + + test("Smoke test for scala flink example") { + implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1) + val (_, result) = DistTrainWithFlink.runPrediction(data, 70) + val results = result.collect().asScala + results should have size 41 + forEvery(results)(item => item should have size 1) + } +} diff --git a/jvm-packages/xgboost4j-flink/pom.xml b/jvm-packages/xgboost4j-flink/pom.xml index b8b757eae..a9a80e29a 100644 --- a/jvm-packages/xgboost4j-flink/pom.xml +++ b/jvm-packages/xgboost4j-flink/pom.xml @@ -8,8 +8,11 @@ xgboost-jvm_2.12 2.0.0-SNAPSHOT - xgboost4j-flink_2.12 + xgboost4j-flink_${scala.binary.version} 2.0.0-SNAPSHOT + + 2.2.0 + @@ -26,32 +29,22 @@ ml.dmlc xgboost4j_${scala.binary.version} - 2.0.0-SNAPSHOT - - - org.apache.commons - commons-lang3 - 3.12.0 + ${project.version} org.apache.flink - flink-scala_${scala.binary.version} + flink-clients ${flink.version} org.apache.flink - flink-clients_${scala.binary.version} - ${flink.version} - - - org.apache.flink - flink-ml_${scala.binary.version} - ${flink.version} + flink-ml-servable-core + ${flink-ml.version} org.apache.hadoop hadoop-common - 3.3.5 + ${hadoop.version} diff --git a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java new file mode 100644 index 000000000..7a5e3ac68 --- /dev/null +++ b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java @@ -0,0 +1,187 @@ +/* + Copyright (c) 2014-2023 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.java.flink; + + +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.util.Collector; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import ml.dmlc.xgboost4j.LabeledPoint; +import ml.dmlc.xgboost4j.java.Booster; +import ml.dmlc.xgboost4j.java.Communicator; +import ml.dmlc.xgboost4j.java.DMatrix; +import ml.dmlc.xgboost4j.java.RabitTracker; +import ml.dmlc.xgboost4j.java.XGBoostError; + + +public class XGBoost { + private static final Logger logger = LoggerFactory.getLogger(XGBoost.class); + + private static class MapFunction + extends RichMapPartitionFunction, XGBoostModel> { + + private final Map params; + private final int round; + private final Map workerEnvs; + + public MapFunction(Map params, int round, Map workerEnvs) { + this.params = params; + this.round = round; + this.workerEnvs = workerEnvs; + } + + public void mapPartition(java.lang.Iterable> it, + Collector collector) throws XGBoostError { + workerEnvs.put( + "DMLC_TASK_ID", + String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask()) + ); + + if (logger.isInfoEnabled()) { + logger.info("start with env: {}", workerEnvs.entrySet().stream() + .map(e -> String.format("\"%s\": \"%s\"", e.getKey(), e.getValue())) + .collect(Collectors.joining(", ")) + ); + } + + final Iterator dataIter = + StreamSupport + .stream(it.spliterator(), false) + .map(VectorToPointMapper.INSTANCE) + .iterator(); + + if (dataIter.hasNext()) { + final DMatrix trainMat = new DMatrix(dataIter, null); + int numEarlyStoppingRounds = + Optional.ofNullable(params.get("numEarlyStoppingRounds")) + .map(x -> Integer.parseInt(x.toString())) + .orElse(0); + + final Booster booster = trainBooster(trainMat, numEarlyStoppingRounds); + collector.collect(new XGBoostModel(booster)); + } else { + logger.warn("Nothing to train with."); + } + } + + private Booster trainBooster(DMatrix trainMat, + int numEarlyStoppingRounds) throws XGBoostError { + Booster booster; + final Map watches = + new HashMap() {{ put("train", trainMat); }}; + try { + Communicator.init(workerEnvs); + booster = ml.dmlc.xgboost4j.java.XGBoost + .train( + trainMat, + params, + round, + watches, + null, + null, + null, + numEarlyStoppingRounds); + } catch (XGBoostError xgbException) { + final String identifier = String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask()); + logger.warn( + String.format("XGBooster worker %s has failed due to", identifier), + xgbException + ); + throw xgbException; + } finally { + Communicator.shutdown(); + } + return booster; + } + + private static class VectorToPointMapper + implements Function, LabeledPoint> { + public static VectorToPointMapper INSTANCE = new VectorToPointMapper(); + @Override + public LabeledPoint apply(Tuple2 tuple) { + final SparseVector vector = tuple.f0.toSparse(); + final double[] values = vector.values; + final int size = values.length; + final float[] array = new float[size]; + for (int i = 0; i < size; i++) { + array[i] = (float) values[i]; + } + return new LabeledPoint( + tuple.f1.floatValue(), + vector.size(), + vector.indices, + array); + } + } + } + + /** + * Load XGBoost model from path, using Hadoop Filesystem API. + * + * @param modelPath The path that is accessible by hadoop filesystem API. + * @return The loaded model + */ + public static XGBoostModel loadModelFromHadoopFile(final String modelPath) throws Exception { + final FileSystem fileSystem = FileSystem.get(new Configuration()); + final Path f = new Path(modelPath); + + try (FSDataInputStream opened = fileSystem.open(f)) { + return new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(opened)); + } + } + + /** + * Train a xgboost model with link. + * + * @param dtrain The training data. + * @param params XGBoost parameters. + * @param numBoostRound Number of rounds to train. + */ + public static XGBoostModel train(DataSet> dtrain, + Map params, + int numBoostRound) throws Exception { + final RabitTracker tracker = + new RabitTracker(dtrain.getExecutionEnvironment().getParallelism()); + if (tracker.start(0L)) { + return dtrain + .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs())) + .reduce((x, y) -> x) + .collect() + .get(0); + } else { + throw new Error("Tracker cannot be started"); + } + } +} diff --git a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoostModel.java b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoostModel.java new file mode 100644 index 000000000..03de50482 --- /dev/null +++ b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoostModel.java @@ -0,0 +1,136 @@ +/* + Copyright (c) 2014-2023 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.java.flink; +import java.io.IOException; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Iterator; +import java.util.stream.StreamSupport; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.util.Collector; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; + +import ml.dmlc.xgboost4j.LabeledPoint; +import ml.dmlc.xgboost4j.java.Booster; +import ml.dmlc.xgboost4j.java.DMatrix; +import ml.dmlc.xgboost4j.java.XGBoostError; + + +public class XGBoostModel implements Serializable { + private static final org.slf4j.Logger logger = + org.slf4j.LoggerFactory.getLogger(XGBoostModel.class); + + private final Booster booster; + private final PredictorFunction predictorFunction; + + + public XGBoostModel(Booster booster) { + this.booster = booster; + this.predictorFunction = new PredictorFunction(booster); + } + + /** + * Save the model as a Hadoop filesystem file. + * + * @param modelPath The model path as in Hadoop path. + */ + public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError { + booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath))); + } + + public byte[] toByteArray(String format) throws XGBoostError { + return booster.toByteArray(format); + } + + /** + * Save the model as a Hadoop filesystem file. + * + * @param modelPath The model path as in Hadoop path. + * @param format The model format (ubj, json, deprecated) + * @throws XGBoostError internal error + * @throws IOException save error + */ + public void saveModelAsHadoopFile(String modelPath, String format) + throws IOException, XGBoostError { + booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)), format); + } + + /** + * predict with the given DMatrix + * + * @param testSet the local test set represented as DMatrix + * @return prediction result + */ + public float[][] predict(DMatrix testSet) throws XGBoostError { + return booster.predict(testSet, true, 0); + } + + /** + * Predict given vector dataset. + * + * @param data The dataset to be predicted. + * @return The prediction result. + */ + public DataSet predict(DataSet data) { + return data.mapPartition(predictorFunction); + } + + + private static class PredictorFunction implements MapPartitionFunction { + + private final Booster booster; + + public PredictorFunction(Booster booster) { + this.booster = booster; + } + + @Override + public void mapPartition(Iterable it, Collector out) throws Exception { + final Iterator dataIter = + StreamSupport.stream(it.spliterator(), false) + .map(Vector::toSparse) + .map(PredictorFunction::fromVector) + .iterator(); + + if (dataIter.hasNext()) { + final DMatrix data = new DMatrix(dataIter, null); + float[][] predictions = booster.predict(data, true, 2); + Arrays.stream(predictions).map(ArrayUtils::toObject).forEach(out::collect); + } else { + logger.debug("Empty partition"); + } + } + + private static LabeledPoint fromVector(SparseVector vector) { + final int[] index = vector.indices; + final double[] value = vector.values; + int size = value.length; + final float[] values = new float[size]; + for (int i = 0; i < size; i++) { + values[i] = (float) value[i]; + } + return new LabeledPoint(0.0f, vector.size(), index, values); + } + } +} diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala deleted file mode 100644 index 6878f1865..000000000 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.flink - -import scala.collection.JavaConverters.asScalaIteratorConverter - -import ml.dmlc.xgboost4j.LabeledPoint -import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker} -import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala} - -import org.apache.commons.logging.LogFactory -import org.apache.flink.api.common.functions.RichMapPartitionFunction -import org.apache.flink.api.scala.{DataSet, _} -import org.apache.flink.ml.common.LabeledVector -import org.apache.flink.util.Collector -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} - -object XGBoost { - /** - * Helper map function to start the job. - * - * @param workerEnvs - */ - private class MapFunction(paramMap: Map[String, Any], - round: Int, - workerEnvs: java.util.Map[String, String]) - extends RichMapPartitionFunction[LabeledVector, XGBoostModel] { - val logger = LogFactory.getLog(this.getClass) - - def mapPartition(it: java.lang.Iterable[LabeledVector], - collector: Collector[XGBoostModel]): Unit = { - workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask)) - logger.info("start with env" + workerEnvs.toString) - Communicator.init(workerEnvs) - val mapper = (x: LabeledVector) => { - val (index, value) = x.vector.toSeq.unzip - LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray) - } - val dataIter = for (x <- it.iterator().asScala) yield mapper(x) - val trainMat = new DMatrix(dataIter, null) - val watches = List("train" -> trainMat).toMap - val round = 2 - val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds") - .map(_.toString.toInt).getOrElse(0) - val booster = XGBoostScala.train(trainMat, paramMap, round, watches, - earlyStoppingRound = numEarlyStoppingRounds) - Communicator.shutdown() - collector.collect(new XGBoostModel(booster)) - } - } - - val logger = LogFactory.getLog(this.getClass) - - /** - * Load XGBoost model from path, using Hadoop Filesystem API. - * - * @param modelPath The path that is accessible by hadoop filesystem API. - * @return The loaded model - */ - def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = { - new XGBoostModel( - XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath)))) - } - - /** - * Train a xgboost model with link. - * - * @param dtrain The training data. - * @param params The parameters to XGBoost. - * @param round Number of rounds to train. - */ - def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int): - XGBoostModel = { - val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism) - if (tracker.start(0L)) { - dtrain - .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs)) - .reduce((x, y) => x).collect().head - } else { - throw new Error("Tracker cannot be started") - null - } - } -} diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala deleted file mode 100644 index 71b376974..000000000 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.flink - -import ml.dmlc.xgboost4j.LabeledPoint -import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} - -import org.apache.flink.api.scala.{DataSet, _} -import org.apache.flink.ml.math.Vector -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} - -class XGBoostModel (booster: Booster) extends Serializable { - /** - * Save the model as a Hadoop filesystem file. - * - * @param modelPath The model path as in Hadoop path. - */ - def saveModelAsHadoopFile(modelPath: String): Unit = { - booster.saveModel(FileSystem - .get(new Configuration) - .create(new Path(modelPath))) - } - - /** - * predict with the given DMatrix - * @param testSet the local test set represented as DMatrix - * @return prediction result - */ - def predict(testSet: DMatrix): Array[Array[Float]] = { - booster.predict(testSet, true, 0) - } - - /** - * Predict given vector dataset. - * - * @param data The dataset to be predicted. - * @return The prediction result. - */ - def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = { - val predictMap: Iterator[Vector] => Traversable[Array[Float]] = - (it: Iterator[Vector]) => { - val mapper = (x: Vector) => { - val (index, value) = x.toSeq.unzip - LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray) - } - val dataIter = for (x <- it) yield mapper(x) - val dmat = new DMatrix(dataIter, null) - this.booster.predict(dmat) - } - data.mapPartition(predictMap) - } -}