Updated flink 1.8 -> 1.17. Added smoke tests for Flink (#9046)
This commit is contained in:
parent
a320b402a5
commit
0e7377ba9c
4
.github/workflows/jvm_tests.yml
vendored
4
.github/workflows/jvm_tests.yml
vendored
@ -40,7 +40,7 @@ jobs:
|
||||
key: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
|
||||
restore-keys: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
|
||||
|
||||
- name: Test XGBoost4J
|
||||
- name: Test XGBoost4J (Core)
|
||||
run: |
|
||||
cd jvm-packages
|
||||
mvn test -B -pl :xgboost4j_2.12
|
||||
@ -67,7 +67,7 @@ jobs:
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }}
|
||||
|
||||
|
||||
- name: Test XGBoost4J-Spark
|
||||
- name: Test XGBoost4J (Core, Spark, Examples)
|
||||
run: |
|
||||
rm -rfv build/
|
||||
cd jvm-packages
|
||||
|
||||
@ -33,7 +33,7 @@
|
||||
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
|
||||
<maven.compiler.source>1.8</maven.compiler.source>
|
||||
<maven.compiler.target>1.8</maven.compiler.target>
|
||||
<flink.version>1.8.3</flink.version>
|
||||
<flink.version>1.17.0</flink.version>
|
||||
<spark.version>3.4.0</spark.version>
|
||||
<scala.version>2.12.17</scala.version>
|
||||
<scala.binary.version>2.12</scala.binary.version>
|
||||
|
||||
@ -26,7 +26,7 @@
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
@ -37,12 +37,7 @@
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.12.0</version>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@ -0,0 +1,107 @@
|
||||
/*
|
||||
Copyright (c) 2014-2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.example.flink;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.flink.api.common.typeinfo.TypeHint;
|
||||
import org.apache.flink.api.common.typeinfo.TypeInformation;
|
||||
import org.apache.flink.api.java.DataSet;
|
||||
import org.apache.flink.api.java.ExecutionEnvironment;
|
||||
import org.apache.flink.api.java.operators.MapOperator;
|
||||
import org.apache.flink.api.java.tuple.Tuple13;
|
||||
import org.apache.flink.api.java.tuple.Tuple2;
|
||||
import org.apache.flink.api.java.utils.DataSetUtils;
|
||||
import org.apache.flink.ml.linalg.DenseVector;
|
||||
import org.apache.flink.ml.linalg.Vector;
|
||||
import org.apache.flink.ml.linalg.Vectors;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.flink.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
|
||||
|
||||
|
||||
public class DistTrainWithFlinkExample {
|
||||
|
||||
static Tuple2<XGBoostModel, DataSet<Float[]>> runPrediction(
|
||||
ExecutionEnvironment env,
|
||||
java.nio.file.Path trainPath,
|
||||
int percentage) throws Exception {
|
||||
// reading data
|
||||
final DataSet<Tuple2<Long, Tuple2<Vector, Double>>> data =
|
||||
DataSetUtils.zipWithIndex(parseCsv(env, trainPath));
|
||||
final long size = data.count();
|
||||
final long trainCount = Math.round(size * 0.01 * percentage);
|
||||
final DataSet<Tuple2<Vector, Double>> trainData =
|
||||
data
|
||||
.filter(item -> item.f0 < trainCount)
|
||||
.map(t -> t.f1)
|
||||
.returns(TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>(){}));
|
||||
final DataSet<Vector> testData =
|
||||
data
|
||||
.filter(tuple -> tuple.f0 >= trainCount)
|
||||
.map(t -> t.f1.f0)
|
||||
.returns(TypeInformation.of(new TypeHint<Vector>(){}));
|
||||
|
||||
// define parameters
|
||||
HashMap<String, Object> paramMap = new HashMap<String, Object>(3);
|
||||
paramMap.put("eta", 0.1);
|
||||
paramMap.put("max_depth", 2);
|
||||
paramMap.put("objective", "binary:logistic");
|
||||
|
||||
// number of iterations
|
||||
final int round = 2;
|
||||
// train the model
|
||||
XGBoostModel model = XGBoost.train(trainData, paramMap, round);
|
||||
DataSet<Float[]> predTest = model.predict(testData);
|
||||
return new Tuple2<XGBoostModel, DataSet<Float[]>>(model, predTest);
|
||||
}
|
||||
|
||||
private static MapOperator<Tuple13<Double, String, Double, Double, Double, Integer, Integer,
|
||||
Integer, Integer, Integer, Integer, Integer, Integer>,
|
||||
Tuple2<Vector, Double>> parseCsv(ExecutionEnvironment env, Path trainPath) {
|
||||
return env.readCsvFile(trainPath.toString())
|
||||
.ignoreFirstLine()
|
||||
.types(Double.class, String.class, Double.class, Double.class, Double.class,
|
||||
Integer.class, Integer.class, Integer.class, Integer.class, Integer.class,
|
||||
Integer.class, Integer.class, Integer.class)
|
||||
.map(DistTrainWithFlinkExample::mapFunction);
|
||||
}
|
||||
|
||||
private static Tuple2<Vector, Double> mapFunction(Tuple13<Double, String, Double, Double, Double,
|
||||
Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer> tuple) {
|
||||
final DenseVector dense = Vectors.dense(tuple.f2, tuple.f3, tuple.f4, tuple.f5, tuple.f6,
|
||||
tuple.f7, tuple.f8, tuple.f9, tuple.f10, tuple.f11, tuple.f12);
|
||||
if (tuple.f1.contains("inf")) {
|
||||
return new Tuple2<Vector, Double>(dense, 1.0);
|
||||
} else {
|
||||
return new Tuple2<Vector, Double>(dense, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
final java.nio.file.Path parentPath = java.nio.file.Paths.get(Arrays.stream(args)
|
||||
.findFirst().orElse("."));
|
||||
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
|
||||
Tuple2<XGBoostModel, DataSet<Float[]>> tuple2 = runPrediction(
|
||||
env, parentPath.resolve("veterans_lung_cancer.csv"), 70
|
||||
);
|
||||
List<Float[]> list = tuple2.f1.collect();
|
||||
System.out.println(list.size());
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -15,27 +15,84 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example.flink
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.flink.XGBoost
|
||||
import org.apache.flink.api.scala.{ExecutionEnvironment, _}
|
||||
import org.apache.flink.ml.MLUtils
|
||||
import java.lang.{Double => JDouble, Long => JLong}
|
||||
import java.nio.file.{Path, Paths}
|
||||
import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
|
||||
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
|
||||
import org.apache.flink.ml.linalg.{Vector, Vectors}
|
||||
import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
|
||||
import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
|
||||
import org.apache.flink.api.java.utils.DataSetUtils
|
||||
|
||||
|
||||
object DistTrainWithFlink {
|
||||
def main(args: Array[String]) {
|
||||
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
|
||||
// read trainining data
|
||||
val trainData =
|
||||
MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
|
||||
val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test")
|
||||
// define parameters
|
||||
val paramMap = List(
|
||||
"eta" -> 0.1,
|
||||
"max_depth" -> 2,
|
||||
"objective" -> "binary:logistic").toMap
|
||||
import scala.jdk.CollectionConverters._
|
||||
private val rowTypeHint = TypeInformation.of(new TypeHint[Tuple2[Vector, JDouble]]{})
|
||||
private val testDataTypeHint = TypeInformation.of(classOf[Vector])
|
||||
|
||||
private[flink] def parseCsv(trainPath: Path)(implicit env: ExecutionEnvironment):
|
||||
DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = {
|
||||
DataSetUtils.zipWithIndex(
|
||||
env
|
||||
.readCsvFile(trainPath.toString)
|
||||
.ignoreFirstLine
|
||||
.types(
|
||||
classOf[Double], classOf[String], classOf[Double], classOf[Double], classOf[Double],
|
||||
classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer],
|
||||
classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer]
|
||||
)
|
||||
.map((row: Tuple13[Double, String, Double, Double, Double,
|
||||
Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer]) => {
|
||||
val dense = Vectors.dense(row.f2, row.f3, row.f4,
|
||||
row.f5.toDouble, row.f6.toDouble, row.f7.toDouble, row.f8.toDouble,
|
||||
row.f9.toDouble, row.f10.toDouble, row.f11.toDouble, row.f12.toDouble)
|
||||
val label = if (row.f1.contains("inf")) {
|
||||
JDouble.valueOf(1.0)
|
||||
} else {
|
||||
JDouble.valueOf(0.0)
|
||||
}
|
||||
new Tuple2[Vector, JDouble](dense, label)
|
||||
})
|
||||
.returns(rowTypeHint)
|
||||
)
|
||||
}
|
||||
|
||||
private[flink] def runPrediction(trainPath: Path, percentage: Int)
|
||||
(implicit env: ExecutionEnvironment):
|
||||
(XGBoostModel, DataSet[Array[Float]]) = {
|
||||
// read training data
|
||||
val data: DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = parseCsv(trainPath)
|
||||
val trainSize = Math.round(0.01 * percentage * data.count())
|
||||
val trainData: DataSet[Tuple2[Vector, JDouble]] =
|
||||
data.filter(d => d.f0 < trainSize).map(_.f1).returns(rowTypeHint)
|
||||
|
||||
|
||||
val testData: DataSet[Vector] =
|
||||
data
|
||||
.filter(d => d.f0 >= trainSize)
|
||||
.map(_.f1.f0)
|
||||
.returns(testDataTypeHint)
|
||||
|
||||
val paramMap = mapAsJavaMap(Map(
|
||||
("eta", "0.1".asInstanceOf[AnyRef]),
|
||||
("max_depth", "2"),
|
||||
("objective", "binary:logistic"),
|
||||
("verbosity", "1")
|
||||
))
|
||||
|
||||
// number of iterations
|
||||
val round = 2
|
||||
// train the model
|
||||
val model = XGBoost.train(trainData, paramMap, round)
|
||||
val predTest = model.predict(testData.map{x => x.vector})
|
||||
model.saveModelAsHadoopFile("file:///path/to/xgboost.model")
|
||||
val result = model.predict(testData).map(prediction => prediction.map(Float.unbox))
|
||||
(model, result)
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
implicit val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
|
||||
val parentPath = Paths.get(args.headOption.getOrElse("."))
|
||||
val (_, predTest) = runPrediction(parentPath.resolve("veterans_lung_cancer.csv"), 70)
|
||||
val list = predTest.collect().asScala
|
||||
println(list.length)
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,36 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.example.flink
|
||||
|
||||
import org.apache.flink.api.java.ExecutionEnvironment
|
||||
import org.scalatest.Inspectors._
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
|
||||
import java.nio.file.Paths
|
||||
|
||||
class DistTrainWithFlinkExampleTest extends AnyFunSuite {
|
||||
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
|
||||
private val data = parentPath.resolve("veterans_lung_cancer.csv")
|
||||
|
||||
test("Smoke test for scala flink example") {
|
||||
val env = ExecutionEnvironment.createLocalEnvironment(1)
|
||||
val tuple2 = DistTrainWithFlinkExample.runPrediction(env, data, 70)
|
||||
val results = tuple2.f1.collect()
|
||||
results should have size 41
|
||||
forEvery(results)(item => item should have size 1)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,37 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example.flink
|
||||
|
||||
import org.apache.flink.api.java.ExecutionEnvironment
|
||||
import org.scalatest.Inspectors._
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
|
||||
import java.nio.file.Paths
|
||||
import scala.jdk.CollectionConverters._
|
||||
|
||||
class DistTrainWithFlinkSuite extends AnyFunSuite {
|
||||
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
|
||||
private val data = parentPath.resolve("veterans_lung_cancer.csv")
|
||||
|
||||
test("Smoke test for scala flink example") {
|
||||
implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1)
|
||||
val (_, result) = DistTrainWithFlink.runPrediction(data, 70)
|
||||
val results = result.collect().asScala
|
||||
results should have size 41
|
||||
forEvery(results)(item => item should have size 1)
|
||||
}
|
||||
}
|
||||
@ -8,8 +8,11 @@
|
||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j-flink_2.12</artifactId>
|
||||
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<properties>
|
||||
<flink-ml.version>2.2.0</flink-ml.version>
|
||||
</properties>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
@ -26,32 +29,22 @@
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.12.0</version>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-scala_${scala.binary.version}</artifactId>
|
||||
<artifactId>flink-clients</artifactId>
|
||||
<version>${flink.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-clients_${scala.binary.version}</artifactId>
|
||||
<version>${flink.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-ml_${scala.binary.version}</artifactId>
|
||||
<version>${flink.version}</version>
|
||||
<artifactId>flink-ml-servable-core</artifactId>
|
||||
<version>${flink-ml.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-common</artifactId>
|
||||
<version>3.3.5</version>
|
||||
<version>${hadoop.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
@ -0,0 +1,187 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java.flink;
|
||||
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
|
||||
import org.apache.flink.api.java.DataSet;
|
||||
import org.apache.flink.api.java.tuple.Tuple2;
|
||||
import org.apache.flink.ml.linalg.SparseVector;
|
||||
import org.apache.flink.ml.linalg.Vector;
|
||||
import org.apache.flink.util.Collector;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.FSDataInputStream;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.Communicator;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.RabitTracker;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
|
||||
public class XGBoost {
|
||||
private static final Logger logger = LoggerFactory.getLogger(XGBoost.class);
|
||||
|
||||
private static class MapFunction
|
||||
extends RichMapPartitionFunction<Tuple2<Vector, Double>, XGBoostModel> {
|
||||
|
||||
private final Map<String, Object> params;
|
||||
private final int round;
|
||||
private final Map<String, String> workerEnvs;
|
||||
|
||||
public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
|
||||
this.params = params;
|
||||
this.round = round;
|
||||
this.workerEnvs = workerEnvs;
|
||||
}
|
||||
|
||||
public void mapPartition(java.lang.Iterable<Tuple2<Vector, Double>> it,
|
||||
Collector<XGBoostModel> collector) throws XGBoostError {
|
||||
workerEnvs.put(
|
||||
"DMLC_TASK_ID",
|
||||
String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask())
|
||||
);
|
||||
|
||||
if (logger.isInfoEnabled()) {
|
||||
logger.info("start with env: {}", workerEnvs.entrySet().stream()
|
||||
.map(e -> String.format("\"%s\": \"%s\"", e.getKey(), e.getValue()))
|
||||
.collect(Collectors.joining(", "))
|
||||
);
|
||||
}
|
||||
|
||||
final Iterator<LabeledPoint> dataIter =
|
||||
StreamSupport
|
||||
.stream(it.spliterator(), false)
|
||||
.map(VectorToPointMapper.INSTANCE)
|
||||
.iterator();
|
||||
|
||||
if (dataIter.hasNext()) {
|
||||
final DMatrix trainMat = new DMatrix(dataIter, null);
|
||||
int numEarlyStoppingRounds =
|
||||
Optional.ofNullable(params.get("numEarlyStoppingRounds"))
|
||||
.map(x -> Integer.parseInt(x.toString()))
|
||||
.orElse(0);
|
||||
|
||||
final Booster booster = trainBooster(trainMat, numEarlyStoppingRounds);
|
||||
collector.collect(new XGBoostModel(booster));
|
||||
} else {
|
||||
logger.warn("Nothing to train with.");
|
||||
}
|
||||
}
|
||||
|
||||
private Booster trainBooster(DMatrix trainMat,
|
||||
int numEarlyStoppingRounds) throws XGBoostError {
|
||||
Booster booster;
|
||||
final Map<String, DMatrix> watches =
|
||||
new HashMap<String, DMatrix>() {{ put("train", trainMat); }};
|
||||
try {
|
||||
Communicator.init(workerEnvs);
|
||||
booster = ml.dmlc.xgboost4j.java.XGBoost
|
||||
.train(
|
||||
trainMat,
|
||||
params,
|
||||
round,
|
||||
watches,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
numEarlyStoppingRounds);
|
||||
} catch (XGBoostError xgbException) {
|
||||
final String identifier = String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask());
|
||||
logger.warn(
|
||||
String.format("XGBooster worker %s has failed due to", identifier),
|
||||
xgbException
|
||||
);
|
||||
throw xgbException;
|
||||
} finally {
|
||||
Communicator.shutdown();
|
||||
}
|
||||
return booster;
|
||||
}
|
||||
|
||||
private static class VectorToPointMapper
|
||||
implements Function<Tuple2<Vector, Double>, LabeledPoint> {
|
||||
public static VectorToPointMapper INSTANCE = new VectorToPointMapper();
|
||||
@Override
|
||||
public LabeledPoint apply(Tuple2<Vector, Double> tuple) {
|
||||
final SparseVector vector = tuple.f0.toSparse();
|
||||
final double[] values = vector.values;
|
||||
final int size = values.length;
|
||||
final float[] array = new float[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
array[i] = (float) values[i];
|
||||
}
|
||||
return new LabeledPoint(
|
||||
tuple.f1.floatValue(),
|
||||
vector.size(),
|
||||
vector.indices,
|
||||
array);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load XGBoost model from path, using Hadoop Filesystem API.
|
||||
*
|
||||
* @param modelPath The path that is accessible by hadoop filesystem API.
|
||||
* @return The loaded model
|
||||
*/
|
||||
public static XGBoostModel loadModelFromHadoopFile(final String modelPath) throws Exception {
|
||||
final FileSystem fileSystem = FileSystem.get(new Configuration());
|
||||
final Path f = new Path(modelPath);
|
||||
|
||||
try (FSDataInputStream opened = fileSystem.open(f)) {
|
||||
return new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(opened));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a xgboost model with link.
|
||||
*
|
||||
* @param dtrain The training data.
|
||||
* @param params XGBoost parameters.
|
||||
* @param numBoostRound Number of rounds to train.
|
||||
*/
|
||||
public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain,
|
||||
Map<String, Object> params,
|
||||
int numBoostRound) throws Exception {
|
||||
final RabitTracker tracker =
|
||||
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
|
||||
if (tracker.start(0L)) {
|
||||
return dtrain
|
||||
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
|
||||
.reduce((x, y) -> x)
|
||||
.collect()
|
||||
.get(0);
|
||||
} else {
|
||||
throw new Error("Tracker cannot be started");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,136 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java.flink;
|
||||
import java.io.IOException;
|
||||
import java.io.Serializable;
|
||||
import java.util.Arrays;
|
||||
import java.util.Iterator;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.apache.flink.api.common.functions.MapPartitionFunction;
|
||||
import org.apache.flink.api.java.DataSet;
|
||||
import org.apache.flink.ml.linalg.SparseVector;
|
||||
import org.apache.flink.ml.linalg.Vector;
|
||||
import org.apache.flink.util.Collector;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
|
||||
public class XGBoostModel implements Serializable {
|
||||
private static final org.slf4j.Logger logger =
|
||||
org.slf4j.LoggerFactory.getLogger(XGBoostModel.class);
|
||||
|
||||
private final Booster booster;
|
||||
private final PredictorFunction predictorFunction;
|
||||
|
||||
|
||||
public XGBoostModel(Booster booster) {
|
||||
this.booster = booster;
|
||||
this.predictorFunction = new PredictorFunction(booster);
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as a Hadoop filesystem file.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
*/
|
||||
public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError {
|
||||
booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)));
|
||||
}
|
||||
|
||||
public byte[] toByteArray(String format) throws XGBoostError {
|
||||
return booster.toByteArray(format);
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as a Hadoop filesystem file.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
* @param format The model format (ubj, json, deprecated)
|
||||
* @throws XGBoostError internal error
|
||||
* @throws IOException save error
|
||||
*/
|
||||
public void saveModelAsHadoopFile(String modelPath, String format)
|
||||
throws IOException, XGBoostError {
|
||||
booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)), format);
|
||||
}
|
||||
|
||||
/**
|
||||
* predict with the given DMatrix
|
||||
*
|
||||
* @param testSet the local test set represented as DMatrix
|
||||
* @return prediction result
|
||||
*/
|
||||
public float[][] predict(DMatrix testSet) throws XGBoostError {
|
||||
return booster.predict(testSet, true, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict given vector dataset.
|
||||
*
|
||||
* @param data The dataset to be predicted.
|
||||
* @return The prediction result.
|
||||
*/
|
||||
public DataSet<Float[]> predict(DataSet<Vector> data) {
|
||||
return data.mapPartition(predictorFunction);
|
||||
}
|
||||
|
||||
|
||||
private static class PredictorFunction implements MapPartitionFunction<Vector, Float[]> {
|
||||
|
||||
private final Booster booster;
|
||||
|
||||
public PredictorFunction(Booster booster) {
|
||||
this.booster = booster;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mapPartition(Iterable<Vector> it, Collector<Float[]> out) throws Exception {
|
||||
final Iterator<LabeledPoint> dataIter =
|
||||
StreamSupport.stream(it.spliterator(), false)
|
||||
.map(Vector::toSparse)
|
||||
.map(PredictorFunction::fromVector)
|
||||
.iterator();
|
||||
|
||||
if (dataIter.hasNext()) {
|
||||
final DMatrix data = new DMatrix(dataIter, null);
|
||||
float[][] predictions = booster.predict(data, true, 2);
|
||||
Arrays.stream(predictions).map(ArrayUtils::toObject).forEach(out::collect);
|
||||
} else {
|
||||
logger.debug("Empty partition");
|
||||
}
|
||||
}
|
||||
|
||||
private static LabeledPoint fromVector(SparseVector vector) {
|
||||
final int[] index = vector.indices;
|
||||
final double[] value = vector.values;
|
||||
int size = value.length;
|
||||
final float[] values = new float[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
values[i] = (float) value[i];
|
||||
}
|
||||
return new LabeledPoint(0.0f, vector.size(), index, values);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,99 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.flink
|
||||
|
||||
import scala.collection.JavaConverters.asScalaIteratorConverter
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.flink.api.common.functions.RichMapPartitionFunction
|
||||
import org.apache.flink.api.scala.{DataSet, _}
|
||||
import org.apache.flink.ml.common.LabeledVector
|
||||
import org.apache.flink.util.Collector
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
|
||||
object XGBoost {
|
||||
/**
|
||||
* Helper map function to start the job.
|
||||
*
|
||||
* @param workerEnvs
|
||||
*/
|
||||
private class MapFunction(paramMap: Map[String, Any],
|
||||
round: Int,
|
||||
workerEnvs: java.util.Map[String, String])
|
||||
extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
|
||||
val logger = LogFactory.getLog(this.getClass)
|
||||
|
||||
def mapPartition(it: java.lang.Iterable[LabeledVector],
|
||||
collector: Collector[XGBoostModel]): Unit = {
|
||||
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
|
||||
logger.info("start with env" + workerEnvs.toString)
|
||||
Communicator.init(workerEnvs)
|
||||
val mapper = (x: LabeledVector) => {
|
||||
val (index, value) = x.vector.toSeq.unzip
|
||||
LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
|
||||
}
|
||||
val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
|
||||
val trainMat = new DMatrix(dataIter, null)
|
||||
val watches = List("train" -> trainMat).toMap
|
||||
val round = 2
|
||||
val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds")
|
||||
.map(_.toString.toInt).getOrElse(0)
|
||||
val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
|
||||
earlyStoppingRound = numEarlyStoppingRounds)
|
||||
Communicator.shutdown()
|
||||
collector.collect(new XGBoostModel(booster))
|
||||
}
|
||||
}
|
||||
|
||||
val logger = LogFactory.getLog(this.getClass)
|
||||
|
||||
/**
|
||||
* Load XGBoost model from path, using Hadoop Filesystem API.
|
||||
*
|
||||
* @param modelPath The path that is accessible by hadoop filesystem API.
|
||||
* @return The loaded model
|
||||
*/
|
||||
def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
|
||||
new XGBoostModel(
|
||||
XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath))))
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a xgboost model with link.
|
||||
*
|
||||
* @param dtrain The training data.
|
||||
* @param params The parameters to XGBoost.
|
||||
* @param round Number of rounds to train.
|
||||
*/
|
||||
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
|
||||
XGBoostModel = {
|
||||
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
|
||||
if (tracker.start(0L)) {
|
||||
dtrain
|
||||
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
|
||||
.reduce((x, y) => x).collect().head
|
||||
} else {
|
||||
throw new Error("Tracker cannot be started")
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,67 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.flink
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
|
||||
import org.apache.flink.api.scala.{DataSet, _}
|
||||
import org.apache.flink.ml.math.Vector
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
|
||||
class XGBoostModel (booster: Booster) extends Serializable {
|
||||
/**
|
||||
* Save the model as a Hadoop filesystem file.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
*/
|
||||
def saveModelAsHadoopFile(modelPath: String): Unit = {
|
||||
booster.saveModel(FileSystem
|
||||
.get(new Configuration)
|
||||
.create(new Path(modelPath)))
|
||||
}
|
||||
|
||||
/**
|
||||
* predict with the given DMatrix
|
||||
* @param testSet the local test set represented as DMatrix
|
||||
* @return prediction result
|
||||
*/
|
||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
||||
booster.predict(testSet, true, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict given vector dataset.
|
||||
*
|
||||
* @param data The dataset to be predicted.
|
||||
* @return The prediction result.
|
||||
*/
|
||||
def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
|
||||
val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
|
||||
(it: Iterator[Vector]) => {
|
||||
val mapper = (x: Vector) => {
|
||||
val (index, value) = x.toSeq.unzip
|
||||
LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
|
||||
}
|
||||
val dataIter = for (x <- it) yield mapper(x)
|
||||
val dmat = new DMatrix(dataIter, null)
|
||||
this.booster.predict(dmat)
|
||||
}
|
||||
data.mapPartition(predictMap)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user