diff --git a/.github/workflows/jvm_tests.yml b/.github/workflows/jvm_tests.yml
index 8efcdc2ec..a2d8bb69a 100644
--- a/.github/workflows/jvm_tests.yml
+++ b/.github/workflows/jvm_tests.yml
@@ -40,7 +40,7 @@ jobs:
key: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
restore-keys: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
- - name: Test XGBoost4J
+ - name: Test XGBoost4J (Core)
run: |
cd jvm-packages
mvn test -B -pl :xgboost4j_2.12
@@ -67,7 +67,7 @@ jobs:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }}
- - name: Test XGBoost4J-Spark
+ - name: Test XGBoost4J (Core, Spark, Examples)
run: |
rm -rfv build/
cd jvm-packages
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 2aac8b00c..0ee7f0b1a 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -33,7 +33,7 @@
UTF-8
1.8
1.8
- 1.8.3
+ 1.17.0
3.4.0
2.12.17
2.12
diff --git a/jvm-packages/xgboost4j-example/pom.xml b/jvm-packages/xgboost4j-example/pom.xml
index d08e4f409..40c9c72a4 100644
--- a/jvm-packages/xgboost4j-example/pom.xml
+++ b/jvm-packages/xgboost4j-example/pom.xml
@@ -26,7 +26,7 @@
ml.dmlc
xgboost4j-spark_${scala.binary.version}
- 2.0.0-SNAPSHOT
+ ${project.version}
org.apache.spark
@@ -37,12 +37,7 @@
ml.dmlc
xgboost4j-flink_${scala.binary.version}
- 2.0.0-SNAPSHOT
-
-
- org.apache.commons
- commons-lang3
- 3.12.0
+ ${project.version}
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExample.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExample.java
new file mode 100644
index 000000000..94e5cdab5
--- /dev/null
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExample.java
@@ -0,0 +1,107 @@
+/*
+ Copyright (c) 2014-2021 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.java.example.flink;
+
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.operators.MapOperator;
+import org.apache.flink.api.java.tuple.Tuple13;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.utils.DataSetUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+
+import ml.dmlc.xgboost4j.java.flink.XGBoost;
+import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
+
+
+public class DistTrainWithFlinkExample {
+
+ static Tuple2> runPrediction(
+ ExecutionEnvironment env,
+ java.nio.file.Path trainPath,
+ int percentage) throws Exception {
+ // reading data
+ final DataSet>> data =
+ DataSetUtils.zipWithIndex(parseCsv(env, trainPath));
+ final long size = data.count();
+ final long trainCount = Math.round(size * 0.01 * percentage);
+ final DataSet> trainData =
+ data
+ .filter(item -> item.f0 < trainCount)
+ .map(t -> t.f1)
+ .returns(TypeInformation.of(new TypeHint>(){}));
+ final DataSet testData =
+ data
+ .filter(tuple -> tuple.f0 >= trainCount)
+ .map(t -> t.f1.f0)
+ .returns(TypeInformation.of(new TypeHint(){}));
+
+ // define parameters
+ HashMap paramMap = new HashMap(3);
+ paramMap.put("eta", 0.1);
+ paramMap.put("max_depth", 2);
+ paramMap.put("objective", "binary:logistic");
+
+ // number of iterations
+ final int round = 2;
+ // train the model
+ XGBoostModel model = XGBoost.train(trainData, paramMap, round);
+ DataSet predTest = model.predict(testData);
+ return new Tuple2>(model, predTest);
+ }
+
+ private static MapOperator,
+ Tuple2> parseCsv(ExecutionEnvironment env, Path trainPath) {
+ return env.readCsvFile(trainPath.toString())
+ .ignoreFirstLine()
+ .types(Double.class, String.class, Double.class, Double.class, Double.class,
+ Integer.class, Integer.class, Integer.class, Integer.class, Integer.class,
+ Integer.class, Integer.class, Integer.class)
+ .map(DistTrainWithFlinkExample::mapFunction);
+ }
+
+ private static Tuple2 mapFunction(Tuple13 tuple) {
+ final DenseVector dense = Vectors.dense(tuple.f2, tuple.f3, tuple.f4, tuple.f5, tuple.f6,
+ tuple.f7, tuple.f8, tuple.f9, tuple.f10, tuple.f11, tuple.f12);
+ if (tuple.f1.contains("inf")) {
+ return new Tuple2(dense, 1.0);
+ } else {
+ return new Tuple2(dense, 0.0);
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ final java.nio.file.Path parentPath = java.nio.file.Paths.get(Arrays.stream(args)
+ .findFirst().orElse("."));
+ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ Tuple2> tuple2 = runPrediction(
+ env, parentPath.resolve("veterans_lung_cancer.csv"), 70
+ );
+ List list = tuple2.f1.collect();
+ System.out.println(list.size());
+ }
+}
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala
index 74b24ac35..cb859f62d 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014 - 2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,27 +15,84 @@
*/
package ml.dmlc.xgboost4j.scala.example.flink
-import ml.dmlc.xgboost4j.scala.flink.XGBoost
-import org.apache.flink.api.scala.{ExecutionEnvironment, _}
-import org.apache.flink.ml.MLUtils
+import java.lang.{Double => JDouble, Long => JLong}
+import java.nio.file.{Path, Paths}
+import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
+import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
+import org.apache.flink.ml.linalg.{Vector, Vectors}
+import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
+import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
+import org.apache.flink.api.java.utils.DataSetUtils
+
object DistTrainWithFlink {
- def main(args: Array[String]) {
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
- // read trainining data
- val trainData =
- MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
- val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test")
- // define parameters
- val paramMap = List(
- "eta" -> 0.1,
- "max_depth" -> 2,
- "objective" -> "binary:logistic").toMap
+ import scala.jdk.CollectionConverters._
+ private val rowTypeHint = TypeInformation.of(new TypeHint[Tuple2[Vector, JDouble]]{})
+ private val testDataTypeHint = TypeInformation.of(classOf[Vector])
+
+ private[flink] def parseCsv(trainPath: Path)(implicit env: ExecutionEnvironment):
+ DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = {
+ DataSetUtils.zipWithIndex(
+ env
+ .readCsvFile(trainPath.toString)
+ .ignoreFirstLine
+ .types(
+ classOf[Double], classOf[String], classOf[Double], classOf[Double], classOf[Double],
+ classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer],
+ classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer]
+ )
+ .map((row: Tuple13[Double, String, Double, Double, Double,
+ Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer]) => {
+ val dense = Vectors.dense(row.f2, row.f3, row.f4,
+ row.f5.toDouble, row.f6.toDouble, row.f7.toDouble, row.f8.toDouble,
+ row.f9.toDouble, row.f10.toDouble, row.f11.toDouble, row.f12.toDouble)
+ val label = if (row.f1.contains("inf")) {
+ JDouble.valueOf(1.0)
+ } else {
+ JDouble.valueOf(0.0)
+ }
+ new Tuple2[Vector, JDouble](dense, label)
+ })
+ .returns(rowTypeHint)
+ )
+ }
+
+ private[flink] def runPrediction(trainPath: Path, percentage: Int)
+ (implicit env: ExecutionEnvironment):
+ (XGBoostModel, DataSet[Array[Float]]) = {
+ // read training data
+ val data: DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = parseCsv(trainPath)
+ val trainSize = Math.round(0.01 * percentage * data.count())
+ val trainData: DataSet[Tuple2[Vector, JDouble]] =
+ data.filter(d => d.f0 < trainSize).map(_.f1).returns(rowTypeHint)
+
+
+ val testData: DataSet[Vector] =
+ data
+ .filter(d => d.f0 >= trainSize)
+ .map(_.f1.f0)
+ .returns(testDataTypeHint)
+
+ val paramMap = mapAsJavaMap(Map(
+ ("eta", "0.1".asInstanceOf[AnyRef]),
+ ("max_depth", "2"),
+ ("objective", "binary:logistic"),
+ ("verbosity", "1")
+ ))
+
// number of iterations
val round = 2
// train the model
val model = XGBoost.train(trainData, paramMap, round)
- val predTest = model.predict(testData.map{x => x.vector})
- model.saveModelAsHadoopFile("file:///path/to/xgboost.model")
+ val result = model.predict(testData).map(prediction => prediction.map(Float.unbox))
+ (model, result)
+ }
+
+ def main(args: Array[String]): Unit = {
+ implicit val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val parentPath = Paths.get(args.headOption.getOrElse("."))
+ val (_, predTest) = runPrediction(parentPath.resolve("veterans_lung_cancer.csv"), 70)
+ val list = predTest.collect().asScala
+ println(list.length)
}
}
diff --git a/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExampleTest.scala b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExampleTest.scala
new file mode 100644
index 000000000..b9929639f
--- /dev/null
+++ b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExampleTest.scala
@@ -0,0 +1,36 @@
+/*
+ Copyright (c) 2014-2023 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.java.example.flink
+
+import org.apache.flink.api.java.ExecutionEnvironment
+import org.scalatest.Inspectors._
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers._
+
+import java.nio.file.Paths
+
+class DistTrainWithFlinkExampleTest extends AnyFunSuite {
+ private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
+ private val data = parentPath.resolve("veterans_lung_cancer.csv")
+
+ test("Smoke test for scala flink example") {
+ val env = ExecutionEnvironment.createLocalEnvironment(1)
+ val tuple2 = DistTrainWithFlinkExample.runPrediction(env, data, 70)
+ val results = tuple2.f1.collect()
+ results should have size 41
+ forEvery(results)(item => item should have size 1)
+ }
+}
diff --git a/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlinkSuite.scala b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlinkSuite.scala
new file mode 100644
index 000000000..d9e98d81c
--- /dev/null
+++ b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlinkSuite.scala
@@ -0,0 +1,37 @@
+/*
+ Copyright (c) 2014-2023 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.scala.example.flink
+
+import org.apache.flink.api.java.ExecutionEnvironment
+import org.scalatest.Inspectors._
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers._
+
+import java.nio.file.Paths
+import scala.jdk.CollectionConverters._
+
+class DistTrainWithFlinkSuite extends AnyFunSuite {
+ private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
+ private val data = parentPath.resolve("veterans_lung_cancer.csv")
+
+ test("Smoke test for scala flink example") {
+ implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1)
+ val (_, result) = DistTrainWithFlink.runPrediction(data, 70)
+ val results = result.collect().asScala
+ results should have size 41
+ forEvery(results)(item => item should have size 1)
+ }
+}
diff --git a/jvm-packages/xgboost4j-flink/pom.xml b/jvm-packages/xgboost4j-flink/pom.xml
index b8b757eae..a9a80e29a 100644
--- a/jvm-packages/xgboost4j-flink/pom.xml
+++ b/jvm-packages/xgboost4j-flink/pom.xml
@@ -8,8 +8,11 @@
xgboost-jvm_2.12
2.0.0-SNAPSHOT
- xgboost4j-flink_2.12
+ xgboost4j-flink_${scala.binary.version}
2.0.0-SNAPSHOT
+
+ 2.2.0
+
@@ -26,32 +29,22 @@
ml.dmlc
xgboost4j_${scala.binary.version}
- 2.0.0-SNAPSHOT
-
-
- org.apache.commons
- commons-lang3
- 3.12.0
+ ${project.version}
org.apache.flink
- flink-scala_${scala.binary.version}
+ flink-clients
${flink.version}
org.apache.flink
- flink-clients_${scala.binary.version}
- ${flink.version}
-
-
- org.apache.flink
- flink-ml_${scala.binary.version}
- ${flink.version}
+ flink-ml-servable-core
+ ${flink-ml.version}
org.apache.hadoop
hadoop-common
- 3.3.5
+ ${hadoop.version}
diff --git a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java
new file mode 100644
index 000000000..7a5e3ac68
--- /dev/null
+++ b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java
@@ -0,0 +1,187 @@
+/*
+ Copyright (c) 2014-2023 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.java.flink;
+
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.util.Collector;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import ml.dmlc.xgboost4j.LabeledPoint;
+import ml.dmlc.xgboost4j.java.Booster;
+import ml.dmlc.xgboost4j.java.Communicator;
+import ml.dmlc.xgboost4j.java.DMatrix;
+import ml.dmlc.xgboost4j.java.RabitTracker;
+import ml.dmlc.xgboost4j.java.XGBoostError;
+
+
+public class XGBoost {
+ private static final Logger logger = LoggerFactory.getLogger(XGBoost.class);
+
+ private static class MapFunction
+ extends RichMapPartitionFunction, XGBoostModel> {
+
+ private final Map params;
+ private final int round;
+ private final Map workerEnvs;
+
+ public MapFunction(Map params, int round, Map workerEnvs) {
+ this.params = params;
+ this.round = round;
+ this.workerEnvs = workerEnvs;
+ }
+
+ public void mapPartition(java.lang.Iterable> it,
+ Collector collector) throws XGBoostError {
+ workerEnvs.put(
+ "DMLC_TASK_ID",
+ String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask())
+ );
+
+ if (logger.isInfoEnabled()) {
+ logger.info("start with env: {}", workerEnvs.entrySet().stream()
+ .map(e -> String.format("\"%s\": \"%s\"", e.getKey(), e.getValue()))
+ .collect(Collectors.joining(", "))
+ );
+ }
+
+ final Iterator dataIter =
+ StreamSupport
+ .stream(it.spliterator(), false)
+ .map(VectorToPointMapper.INSTANCE)
+ .iterator();
+
+ if (dataIter.hasNext()) {
+ final DMatrix trainMat = new DMatrix(dataIter, null);
+ int numEarlyStoppingRounds =
+ Optional.ofNullable(params.get("numEarlyStoppingRounds"))
+ .map(x -> Integer.parseInt(x.toString()))
+ .orElse(0);
+
+ final Booster booster = trainBooster(trainMat, numEarlyStoppingRounds);
+ collector.collect(new XGBoostModel(booster));
+ } else {
+ logger.warn("Nothing to train with.");
+ }
+ }
+
+ private Booster trainBooster(DMatrix trainMat,
+ int numEarlyStoppingRounds) throws XGBoostError {
+ Booster booster;
+ final Map watches =
+ new HashMap() {{ put("train", trainMat); }};
+ try {
+ Communicator.init(workerEnvs);
+ booster = ml.dmlc.xgboost4j.java.XGBoost
+ .train(
+ trainMat,
+ params,
+ round,
+ watches,
+ null,
+ null,
+ null,
+ numEarlyStoppingRounds);
+ } catch (XGBoostError xgbException) {
+ final String identifier = String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask());
+ logger.warn(
+ String.format("XGBooster worker %s has failed due to", identifier),
+ xgbException
+ );
+ throw xgbException;
+ } finally {
+ Communicator.shutdown();
+ }
+ return booster;
+ }
+
+ private static class VectorToPointMapper
+ implements Function, LabeledPoint> {
+ public static VectorToPointMapper INSTANCE = new VectorToPointMapper();
+ @Override
+ public LabeledPoint apply(Tuple2 tuple) {
+ final SparseVector vector = tuple.f0.toSparse();
+ final double[] values = vector.values;
+ final int size = values.length;
+ final float[] array = new float[size];
+ for (int i = 0; i < size; i++) {
+ array[i] = (float) values[i];
+ }
+ return new LabeledPoint(
+ tuple.f1.floatValue(),
+ vector.size(),
+ vector.indices,
+ array);
+ }
+ }
+ }
+
+ /**
+ * Load XGBoost model from path, using Hadoop Filesystem API.
+ *
+ * @param modelPath The path that is accessible by hadoop filesystem API.
+ * @return The loaded model
+ */
+ public static XGBoostModel loadModelFromHadoopFile(final String modelPath) throws Exception {
+ final FileSystem fileSystem = FileSystem.get(new Configuration());
+ final Path f = new Path(modelPath);
+
+ try (FSDataInputStream opened = fileSystem.open(f)) {
+ return new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(opened));
+ }
+ }
+
+ /**
+ * Train a xgboost model with link.
+ *
+ * @param dtrain The training data.
+ * @param params XGBoost parameters.
+ * @param numBoostRound Number of rounds to train.
+ */
+ public static XGBoostModel train(DataSet> dtrain,
+ Map params,
+ int numBoostRound) throws Exception {
+ final RabitTracker tracker =
+ new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
+ if (tracker.start(0L)) {
+ return dtrain
+ .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
+ .reduce((x, y) -> x)
+ .collect()
+ .get(0);
+ } else {
+ throw new Error("Tracker cannot be started");
+ }
+ }
+}
diff --git a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoostModel.java b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoostModel.java
new file mode 100644
index 000000000..03de50482
--- /dev/null
+++ b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoostModel.java
@@ -0,0 +1,136 @@
+/*
+ Copyright (c) 2014-2023 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.java.flink;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.stream.StreamSupport;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.util.Collector;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+
+import ml.dmlc.xgboost4j.LabeledPoint;
+import ml.dmlc.xgboost4j.java.Booster;
+import ml.dmlc.xgboost4j.java.DMatrix;
+import ml.dmlc.xgboost4j.java.XGBoostError;
+
+
+public class XGBoostModel implements Serializable {
+ private static final org.slf4j.Logger logger =
+ org.slf4j.LoggerFactory.getLogger(XGBoostModel.class);
+
+ private final Booster booster;
+ private final PredictorFunction predictorFunction;
+
+
+ public XGBoostModel(Booster booster) {
+ this.booster = booster;
+ this.predictorFunction = new PredictorFunction(booster);
+ }
+
+ /**
+ * Save the model as a Hadoop filesystem file.
+ *
+ * @param modelPath The model path as in Hadoop path.
+ */
+ public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError {
+ booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)));
+ }
+
+ public byte[] toByteArray(String format) throws XGBoostError {
+ return booster.toByteArray(format);
+ }
+
+ /**
+ * Save the model as a Hadoop filesystem file.
+ *
+ * @param modelPath The model path as in Hadoop path.
+ * @param format The model format (ubj, json, deprecated)
+ * @throws XGBoostError internal error
+ * @throws IOException save error
+ */
+ public void saveModelAsHadoopFile(String modelPath, String format)
+ throws IOException, XGBoostError {
+ booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)), format);
+ }
+
+ /**
+ * predict with the given DMatrix
+ *
+ * @param testSet the local test set represented as DMatrix
+ * @return prediction result
+ */
+ public float[][] predict(DMatrix testSet) throws XGBoostError {
+ return booster.predict(testSet, true, 0);
+ }
+
+ /**
+ * Predict given vector dataset.
+ *
+ * @param data The dataset to be predicted.
+ * @return The prediction result.
+ */
+ public DataSet predict(DataSet data) {
+ return data.mapPartition(predictorFunction);
+ }
+
+
+ private static class PredictorFunction implements MapPartitionFunction {
+
+ private final Booster booster;
+
+ public PredictorFunction(Booster booster) {
+ this.booster = booster;
+ }
+
+ @Override
+ public void mapPartition(Iterable it, Collector out) throws Exception {
+ final Iterator dataIter =
+ StreamSupport.stream(it.spliterator(), false)
+ .map(Vector::toSparse)
+ .map(PredictorFunction::fromVector)
+ .iterator();
+
+ if (dataIter.hasNext()) {
+ final DMatrix data = new DMatrix(dataIter, null);
+ float[][] predictions = booster.predict(data, true, 2);
+ Arrays.stream(predictions).map(ArrayUtils::toObject).forEach(out::collect);
+ } else {
+ logger.debug("Empty partition");
+ }
+ }
+
+ private static LabeledPoint fromVector(SparseVector vector) {
+ final int[] index = vector.indices;
+ final double[] value = vector.values;
+ int size = value.length;
+ final float[] values = new float[size];
+ for (int i = 0; i < size; i++) {
+ values[i] = (float) value[i];
+ }
+ return new LabeledPoint(0.0f, vector.size(), index, values);
+ }
+ }
+}
diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala
deleted file mode 100644
index 6878f1865..000000000
--- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.flink
-
-import scala.collection.JavaConverters.asScalaIteratorConverter
-
-import ml.dmlc.xgboost4j.LabeledPoint
-import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
-import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
-
-import org.apache.commons.logging.LogFactory
-import org.apache.flink.api.common.functions.RichMapPartitionFunction
-import org.apache.flink.api.scala.{DataSet, _}
-import org.apache.flink.ml.common.LabeledVector
-import org.apache.flink.util.Collector
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-object XGBoost {
- /**
- * Helper map function to start the job.
- *
- * @param workerEnvs
- */
- private class MapFunction(paramMap: Map[String, Any],
- round: Int,
- workerEnvs: java.util.Map[String, String])
- extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
- val logger = LogFactory.getLog(this.getClass)
-
- def mapPartition(it: java.lang.Iterable[LabeledVector],
- collector: Collector[XGBoostModel]): Unit = {
- workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
- logger.info("start with env" + workerEnvs.toString)
- Communicator.init(workerEnvs)
- val mapper = (x: LabeledVector) => {
- val (index, value) = x.vector.toSeq.unzip
- LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
- }
- val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
- val trainMat = new DMatrix(dataIter, null)
- val watches = List("train" -> trainMat).toMap
- val round = 2
- val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds")
- .map(_.toString.toInt).getOrElse(0)
- val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
- earlyStoppingRound = numEarlyStoppingRounds)
- Communicator.shutdown()
- collector.collect(new XGBoostModel(booster))
- }
- }
-
- val logger = LogFactory.getLog(this.getClass)
-
- /**
- * Load XGBoost model from path, using Hadoop Filesystem API.
- *
- * @param modelPath The path that is accessible by hadoop filesystem API.
- * @return The loaded model
- */
- def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
- new XGBoostModel(
- XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath))))
- }
-
- /**
- * Train a xgboost model with link.
- *
- * @param dtrain The training data.
- * @param params The parameters to XGBoost.
- * @param round Number of rounds to train.
- */
- def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
- XGBoostModel = {
- val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
- if (tracker.start(0L)) {
- dtrain
- .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
- .reduce((x, y) => x).collect().head
- } else {
- throw new Error("Tracker cannot be started")
- null
- }
- }
-}
diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala
deleted file mode 100644
index 71b376974..000000000
--- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.flink
-
-import ml.dmlc.xgboost4j.LabeledPoint
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
-
-import org.apache.flink.api.scala.{DataSet, _}
-import org.apache.flink.ml.math.Vector
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-class XGBoostModel (booster: Booster) extends Serializable {
- /**
- * Save the model as a Hadoop filesystem file.
- *
- * @param modelPath The model path as in Hadoop path.
- */
- def saveModelAsHadoopFile(modelPath: String): Unit = {
- booster.saveModel(FileSystem
- .get(new Configuration)
- .create(new Path(modelPath)))
- }
-
- /**
- * predict with the given DMatrix
- * @param testSet the local test set represented as DMatrix
- * @return prediction result
- */
- def predict(testSet: DMatrix): Array[Array[Float]] = {
- booster.predict(testSet, true, 0)
- }
-
- /**
- * Predict given vector dataset.
- *
- * @param data The dataset to be predicted.
- * @return The prediction result.
- */
- def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
- val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
- (it: Iterator[Vector]) => {
- val mapper = (x: Vector) => {
- val (index, value) = x.toSeq.unzip
- LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
- }
- val dataIter = for (x <- it) yield mapper(x)
- val dmat = new DMatrix(dataIter, null)
- this.booster.predict(dmat)
- }
- data.mapPartition(predictMap)
- }
-}