Updated flink 1.8 -> 1.17. Added smoke tests for Flink (#9046)
This commit is contained in:
@@ -0,0 +1,107 @@
|
||||
/*
|
||||
Copyright (c) 2014-2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.example.flink;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.flink.api.common.typeinfo.TypeHint;
|
||||
import org.apache.flink.api.common.typeinfo.TypeInformation;
|
||||
import org.apache.flink.api.java.DataSet;
|
||||
import org.apache.flink.api.java.ExecutionEnvironment;
|
||||
import org.apache.flink.api.java.operators.MapOperator;
|
||||
import org.apache.flink.api.java.tuple.Tuple13;
|
||||
import org.apache.flink.api.java.tuple.Tuple2;
|
||||
import org.apache.flink.api.java.utils.DataSetUtils;
|
||||
import org.apache.flink.ml.linalg.DenseVector;
|
||||
import org.apache.flink.ml.linalg.Vector;
|
||||
import org.apache.flink.ml.linalg.Vectors;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.flink.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
|
||||
|
||||
|
||||
public class DistTrainWithFlinkExample {
|
||||
|
||||
static Tuple2<XGBoostModel, DataSet<Float[]>> runPrediction(
|
||||
ExecutionEnvironment env,
|
||||
java.nio.file.Path trainPath,
|
||||
int percentage) throws Exception {
|
||||
// reading data
|
||||
final DataSet<Tuple2<Long, Tuple2<Vector, Double>>> data =
|
||||
DataSetUtils.zipWithIndex(parseCsv(env, trainPath));
|
||||
final long size = data.count();
|
||||
final long trainCount = Math.round(size * 0.01 * percentage);
|
||||
final DataSet<Tuple2<Vector, Double>> trainData =
|
||||
data
|
||||
.filter(item -> item.f0 < trainCount)
|
||||
.map(t -> t.f1)
|
||||
.returns(TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>(){}));
|
||||
final DataSet<Vector> testData =
|
||||
data
|
||||
.filter(tuple -> tuple.f0 >= trainCount)
|
||||
.map(t -> t.f1.f0)
|
||||
.returns(TypeInformation.of(new TypeHint<Vector>(){}));
|
||||
|
||||
// define parameters
|
||||
HashMap<String, Object> paramMap = new HashMap<String, Object>(3);
|
||||
paramMap.put("eta", 0.1);
|
||||
paramMap.put("max_depth", 2);
|
||||
paramMap.put("objective", "binary:logistic");
|
||||
|
||||
// number of iterations
|
||||
final int round = 2;
|
||||
// train the model
|
||||
XGBoostModel model = XGBoost.train(trainData, paramMap, round);
|
||||
DataSet<Float[]> predTest = model.predict(testData);
|
||||
return new Tuple2<XGBoostModel, DataSet<Float[]>>(model, predTest);
|
||||
}
|
||||
|
||||
private static MapOperator<Tuple13<Double, String, Double, Double, Double, Integer, Integer,
|
||||
Integer, Integer, Integer, Integer, Integer, Integer>,
|
||||
Tuple2<Vector, Double>> parseCsv(ExecutionEnvironment env, Path trainPath) {
|
||||
return env.readCsvFile(trainPath.toString())
|
||||
.ignoreFirstLine()
|
||||
.types(Double.class, String.class, Double.class, Double.class, Double.class,
|
||||
Integer.class, Integer.class, Integer.class, Integer.class, Integer.class,
|
||||
Integer.class, Integer.class, Integer.class)
|
||||
.map(DistTrainWithFlinkExample::mapFunction);
|
||||
}
|
||||
|
||||
private static Tuple2<Vector, Double> mapFunction(Tuple13<Double, String, Double, Double, Double,
|
||||
Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer> tuple) {
|
||||
final DenseVector dense = Vectors.dense(tuple.f2, tuple.f3, tuple.f4, tuple.f5, tuple.f6,
|
||||
tuple.f7, tuple.f8, tuple.f9, tuple.f10, tuple.f11, tuple.f12);
|
||||
if (tuple.f1.contains("inf")) {
|
||||
return new Tuple2<Vector, Double>(dense, 1.0);
|
||||
} else {
|
||||
return new Tuple2<Vector, Double>(dense, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
final java.nio.file.Path parentPath = java.nio.file.Paths.get(Arrays.stream(args)
|
||||
.findFirst().orElse("."));
|
||||
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
|
||||
Tuple2<XGBoostModel, DataSet<Float[]>> tuple2 = runPrediction(
|
||||
env, parentPath.resolve("veterans_lung_cancer.csv"), 70
|
||||
);
|
||||
List<Float[]> list = tuple2.f1.collect();
|
||||
System.out.println(list.size());
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 - 2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -15,27 +15,84 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example.flink
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.flink.XGBoost
|
||||
import org.apache.flink.api.scala.{ExecutionEnvironment, _}
|
||||
import org.apache.flink.ml.MLUtils
|
||||
import java.lang.{Double => JDouble, Long => JLong}
|
||||
import java.nio.file.{Path, Paths}
|
||||
import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
|
||||
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
|
||||
import org.apache.flink.ml.linalg.{Vector, Vectors}
|
||||
import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
|
||||
import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
|
||||
import org.apache.flink.api.java.utils.DataSetUtils
|
||||
|
||||
|
||||
object DistTrainWithFlink {
  import scala.jdk.CollectionConverters._

  // Explicit type hints: the Java DataSet API cannot infer generic types
  // produced by Scala lambdas.
  private val rowTypeHint = TypeInformation.of(new TypeHint[Tuple2[Vector, JDouble]]{})
  private val testDataTypeHint = TypeInformation.of(classOf[Vector])

  /** Reads the veterans lung cancer CSV (header skipped) and turns every
    * 13-column row into a (feature vector, label) pair, zipped with a stable
    * row index used later for the train/test split.
    */
  private[flink] def parseCsv(trainPath: Path)(implicit env: ExecutionEnvironment):
      DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = {
    DataSetUtils.zipWithIndex(
      env
        .readCsvFile(trainPath.toString)
        .ignoreFirstLine
        .types(
          classOf[Double], classOf[String], classOf[Double], classOf[Double], classOf[Double],
          classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer],
          classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer]
        )
        .map((row: Tuple13[Double, String, Double, Double, Double,
          Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer]) => {
          val dense = Vectors.dense(row.f2, row.f3, row.f4,
            row.f5.toDouble, row.f6.toDouble, row.f7.toDouble, row.f8.toDouble,
            row.f9.toDouble, row.f10.toDouble, row.f11.toDouble, row.f12.toDouble)
          // Rows whose status column mentions "inf" are labelled 1.0, else 0.0.
          val label = if (row.f1.contains("inf")) {
            JDouble.valueOf(1.0)
          } else {
            JDouble.valueOf(0.0)
          }
          new Tuple2[Vector, JDouble](dense, label)
        })
        .returns(rowTypeHint)
    )
  }

  /** Trains an XGBoost model on the first `percentage` percent of the rows
    * and predicts on the remainder.
    *
    * @param trainPath  path to the input CSV file
    * @param percentage percentage (0-100) of rows used for training
    * @return the trained model and the predictions for the held-out rows
    */
  private[flink] def runPrediction(trainPath: Path, percentage: Int)
      (implicit env: ExecutionEnvironment):
      (XGBoostModel, DataSet[Array[Float]]) = {
    // Read training data with a stable per-row index.
    val data: DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = parseCsv(trainPath)
    val trainSize = Math.round(0.01 * percentage * data.count())
    val trainData: DataSet[Tuple2[Vector, JDouble]] =
      data.filter(d => d.f0 < trainSize).map(_.f1).returns(rowTypeHint)

    val testData: DataSet[Vector] =
      data
        .filter(d => d.f0 >= trainSize)
        .map(_.f1.f0)
        .returns(testDataTypeHint)

    // Booster parameters; values as AnyRef because the Java API expects
    // a java.util.Map[String, AnyRef].
    val paramMap = Map[String, AnyRef](
      "eta" -> "0.1",
      "max_depth" -> "2",
      "objective" -> "binary:logistic",
      "verbosity" -> "1"
    ).asJava

    // Number of boosting iterations.
    val round = 2
    // Train the model and unbox the predictions to Scala Array[Float].
    val model = XGBoost.train(trainData, paramMap, round)
    val result = model.predict(testData).map(prediction => prediction.map(Float.unbox))
    (model, result)
  }

  def main(args: Array[String]): Unit = {
    implicit val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    // Optional first argument: directory containing the CSV (defaults to ".").
    val parentPath = Paths.get(args.headOption.getOrElse("."))
    val (_, predTest) = runPrediction(parentPath.resolve("veterans_lung_cancer.csv"), 70)
    val list = predTest.collect().asScala
    println(list.length)
  }
}
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.example.flink
|
||||
|
||||
import org.apache.flink.api.java.ExecutionEnvironment
|
||||
import org.scalatest.Inspectors._
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
|
||||
import java.nio.file.Paths
|
||||
|
||||
class DistTrainWithFlinkExampleTest extends AnyFunSuite {
  // Demo data set shipped with the repository, two levels above the module.
  private val dataDirectory = Paths.get("../../").resolve("demo").resolve("data")
  private val csvFile = dataDirectory.resolve("veterans_lung_cancer.csv")

  test("Smoke test for scala flink example") {
    // A single-slot local environment keeps the smoke test lightweight.
    val environment = ExecutionEnvironment.createLocalEnvironment(1)
    val modelAndPredictions = DistTrainWithFlinkExample.runPrediction(environment, csvFile, 70)
    val predictions = modelAndPredictions.f1.collect()
    predictions should have size 41
    forEvery(predictions)(prediction => prediction should have size 1)
  }
}
|
||||
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example.flink
|
||||
|
||||
import org.apache.flink.api.java.ExecutionEnvironment
|
||||
import org.scalatest.Inspectors._
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
|
||||
import java.nio.file.Paths
|
||||
import scala.jdk.CollectionConverters._
|
||||
|
||||
class DistTrainWithFlinkSuite extends AnyFunSuite {
  // Demo data set shipped with the repository, two levels above the module.
  private val dataDirectory = Paths.get("../../").resolve("demo").resolve("data")
  private val csvFile = dataDirectory.resolve("veterans_lung_cancer.csv")

  test("Smoke test for scala flink example") {
    // runPrediction takes the environment implicitly; use one local slot.
    implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1)
    val (_, predictions) = DistTrainWithFlink.runPrediction(csvFile, 70)
    val collected = predictions.collect().asScala
    collected should have size 41
    forEvery(collected)(prediction => prediction should have size 1)
  }
}
|
||||
Reference in New Issue
Block a user