Updated flink 1.8 -> 1.17. Added smoke tests for Flink (#9046)
parent a320b402a5
commit 0e7377ba9c
.github/workflows/jvm_tests.yml (vendored)
@@ -40,7 +40,7 @@ jobs:
           key: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
           restore-keys: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
 
-      - name: Test XGBoost4J
+      - name: Test XGBoost4J (Core)
        run: |
          cd jvm-packages
          mvn test -B -pl :xgboost4j_2.12
@@ -67,7 +67,7 @@ jobs:
          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }}
 
 
-      - name: Test XGBoost4J-Spark
+      - name: Test XGBoost4J (Core, Spark, Examples)
        run: |
          rm -rfv build/
          cd jvm-packages
jvm-packages/pom.xml
@@ -33,7 +33,7 @@
     <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
     <maven.compiler.source>1.8</maven.compiler.source>
     <maven.compiler.target>1.8</maven.compiler.target>
-    <flink.version>1.8.3</flink.version>
+    <flink.version>1.17.0</flink.version>
     <spark.version>3.4.0</spark.version>
     <scala.version>2.12.17</scala.version>
     <scala.binary.version>2.12</scala.binary.version>
jvm-packages/xgboost4j-example/pom.xml
@@ -26,7 +26,7 @@
     <dependency>
       <groupId>ml.dmlc</groupId>
       <artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
-      <version>2.0.0-SNAPSHOT</version>
+      <version>${project.version}</version>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
@@ -37,12 +37,7 @@
     <dependency>
       <groupId>ml.dmlc</groupId>
       <artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
-      <version>2.0.0-SNAPSHOT</version>
+      <version>${project.version}</version>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.commons</groupId>
-      <artifactId>commons-lang3</artifactId>
-      <version>3.12.0</version>
     </dependency>
   </dependencies>
 </project>
DistTrainWithFlinkExample.java (new file)
@@ -0,0 +1,107 @@
+/*
+  Copyright (c) 2014-2021 by Contributors
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+package ml.dmlc.xgboost4j.java.example.flink;
+
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.operators.MapOperator;
+import org.apache.flink.api.java.tuple.Tuple13;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.utils.DataSetUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+
+import ml.dmlc.xgboost4j.java.flink.XGBoost;
+import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
+
+
+public class DistTrainWithFlinkExample {
+
+  static Tuple2<XGBoostModel, DataSet<Float[]>> runPrediction(
+      ExecutionEnvironment env,
+      java.nio.file.Path trainPath,
+      int percentage) throws Exception {
+    // reading data
+    final DataSet<Tuple2<Long, Tuple2<Vector, Double>>> data =
+        DataSetUtils.zipWithIndex(parseCsv(env, trainPath));
+    final long size = data.count();
+    final long trainCount = Math.round(size * 0.01 * percentage);
+    final DataSet<Tuple2<Vector, Double>> trainData =
+        data
+            .filter(item -> item.f0 < trainCount)
+            .map(t -> t.f1)
+            .returns(TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>(){}));
+    final DataSet<Vector> testData =
+        data
+            .filter(tuple -> tuple.f0 >= trainCount)
+            .map(t -> t.f1.f0)
+            .returns(TypeInformation.of(new TypeHint<Vector>(){}));
+
+    // define parameters
+    HashMap<String, Object> paramMap = new HashMap<String, Object>(3);
+    paramMap.put("eta", 0.1);
+    paramMap.put("max_depth", 2);
+    paramMap.put("objective", "binary:logistic");
+
+    // number of iterations
+    final int round = 2;
+    // train the model
+    XGBoostModel model = XGBoost.train(trainData, paramMap, round);
+    DataSet<Float[]> predTest = model.predict(testData);
+    return new Tuple2<XGBoostModel, DataSet<Float[]>>(model, predTest);
+  }
+
+  private static MapOperator<Tuple13<Double, String, Double, Double, Double, Integer, Integer,
+      Integer, Integer, Integer, Integer, Integer, Integer>,
+      Tuple2<Vector, Double>> parseCsv(ExecutionEnvironment env, Path trainPath) {
+    return env.readCsvFile(trainPath.toString())
+        .ignoreFirstLine()
+        .types(Double.class, String.class, Double.class, Double.class, Double.class,
+            Integer.class, Integer.class, Integer.class, Integer.class, Integer.class,
+            Integer.class, Integer.class, Integer.class)
+        .map(DistTrainWithFlinkExample::mapFunction);
+  }
+
+  private static Tuple2<Vector, Double> mapFunction(Tuple13<Double, String, Double, Double, Double,
+      Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer> tuple) {
+    final DenseVector dense = Vectors.dense(tuple.f2, tuple.f3, tuple.f4, tuple.f5, tuple.f6,
+        tuple.f7, tuple.f8, tuple.f9, tuple.f10, tuple.f11, tuple.f12);
+    if (tuple.f1.contains("inf")) {
+      return new Tuple2<Vector, Double>(dense, 1.0);
+    } else {
+      return new Tuple2<Vector, Double>(dense, 0.0);
+    }
+  }
+
+  public static void main(String[] args) throws Exception {
+    final java.nio.file.Path parentPath = java.nio.file.Paths.get(Arrays.stream(args)
+        .findFirst().orElse("."));
+    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+    Tuple2<XGBoostModel, DataSet<Float[]>> tuple2 = runPrediction(
+        env, parentPath.resolve("veterans_lung_cancer.csv"), 70
+    );
+    List<Float[]> list = tuple2.f1.collect();
+    System.out.println(list.size());
+  }
+}
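For reference, runPrediction above takes the execution environment, the path of the training CSV, and the percentage of rows to train on, and returns the trained model together with predictions for the held-out rows. A minimal sketch of driving it directly follows; the driver must sit in the example's package because runPrediction is package-private, and the demo-data path is an assumption:

package ml.dmlc.xgboost4j.java.example.flink;

import java.nio.file.Paths;
import java.util.List;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;

import ml.dmlc.xgboost4j.java.flink.XGBoostModel;

// Hypothetical local driver for the example above; the data path is illustrative.
public class RunExampleLocally {
  public static void main(String[] args) throws Exception {
    // Single-slot local environment, as the smoke tests further down also use.
    final ExecutionEnvironment env = ExecutionEnvironment.createLocalEnvironment(1);
    // Train on the first 70% of rows, predict the remaining 30%.
    final Tuple2<XGBoostModel, DataSet<Float[]>> result =
        DistTrainWithFlinkExample.runPrediction(
            env, Paths.get("demo/data/veterans_lung_cancer.csv"), 70);
    final List<Float[]> predictions = result.f1.collect();
    System.out.println(predictions.size());
  }
}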
DistTrainWithFlink.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014 - 2023 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -15,27 +15,84 @@
  */
 package ml.dmlc.xgboost4j.scala.example.flink
 
-import ml.dmlc.xgboost4j.scala.flink.XGBoost
-import org.apache.flink.api.scala.{ExecutionEnvironment, _}
-import org.apache.flink.ml.MLUtils
+import java.lang.{Double => JDouble, Long => JLong}
+import java.nio.file.{Path, Paths}
+import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
+import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
+import org.apache.flink.ml.linalg.{Vector, Vectors}
+import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
+import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
+import org.apache.flink.api.java.utils.DataSetUtils
 
 
 object DistTrainWithFlink {
-  def main(args: Array[String]) {
-    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
-    // read trainining data
-    val trainData =
-      MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
-    val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test")
-    // define parameters
-    val paramMap = List(
-      "eta" -> 0.1,
-      "max_depth" -> 2,
-      "objective" -> "binary:logistic").toMap
+  import scala.jdk.CollectionConverters._
+  private val rowTypeHint = TypeInformation.of(new TypeHint[Tuple2[Vector, JDouble]]{})
+  private val testDataTypeHint = TypeInformation.of(classOf[Vector])
+
+  private[flink] def parseCsv(trainPath: Path)(implicit env: ExecutionEnvironment):
+      DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = {
+    DataSetUtils.zipWithIndex(
+      env
+        .readCsvFile(trainPath.toString)
+        .ignoreFirstLine
+        .types(
+          classOf[Double], classOf[String], classOf[Double], classOf[Double], classOf[Double],
+          classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer],
+          classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer]
+        )
+        .map((row: Tuple13[Double, String, Double, Double, Double,
+            Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer]) => {
+          val dense = Vectors.dense(row.f2, row.f3, row.f4,
+            row.f5.toDouble, row.f6.toDouble, row.f7.toDouble, row.f8.toDouble,
+            row.f9.toDouble, row.f10.toDouble, row.f11.toDouble, row.f12.toDouble)
+          val label = if (row.f1.contains("inf")) {
+            JDouble.valueOf(1.0)
+          } else {
+            JDouble.valueOf(0.0)
+          }
+          new Tuple2[Vector, JDouble](dense, label)
+        })
+        .returns(rowTypeHint)
+    )
+  }
+
+  private[flink] def runPrediction(trainPath: Path, percentage: Int)
+      (implicit env: ExecutionEnvironment):
+      (XGBoostModel, DataSet[Array[Float]]) = {
+    // read training data
+    val data: DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = parseCsv(trainPath)
+    val trainSize = Math.round(0.01 * percentage * data.count())
+    val trainData: DataSet[Tuple2[Vector, JDouble]] =
+      data.filter(d => d.f0 < trainSize).map(_.f1).returns(rowTypeHint)
+
+    val testData: DataSet[Vector] =
+      data
+        .filter(d => d.f0 >= trainSize)
+        .map(_.f1.f0)
+        .returns(testDataTypeHint)
+
+    val paramMap = mapAsJavaMap(Map(
+      ("eta", "0.1".asInstanceOf[AnyRef]),
+      ("max_depth", "2"),
+      ("objective", "binary:logistic"),
+      ("verbosity", "1")
+    ))
+
     // number of iterations
     val round = 2
     // train the model
     val model = XGBoost.train(trainData, paramMap, round)
-    val predTest = model.predict(testData.map{x => x.vector})
-    model.saveModelAsHadoopFile("file:///path/to/xgboost.model")
+    val result = model.predict(testData).map(prediction => prediction.map(Float.unbox))
+    (model, result)
+  }
+
+  def main(args: Array[String]): Unit = {
+    implicit val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val parentPath = Paths.get(args.headOption.getOrElse("."))
+    val (_, predTest) = runPrediction(parentPath.resolve("veterans_lung_cancer.csv"), 70)
+    val list = predTest.collect().asScala
+    println(list.length)
   }
 }
DistTrainWithFlinkExampleTest.scala (new file)
@@ -0,0 +1,36 @@
+/*
+  Copyright (c) 2014-2023 by Contributors
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+package ml.dmlc.xgboost4j.java.example.flink
+
+import org.apache.flink.api.java.ExecutionEnvironment
+import org.scalatest.Inspectors._
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers._
+
+import java.nio.file.Paths
+
+class DistTrainWithFlinkExampleTest extends AnyFunSuite {
+  private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
+  private val data = parentPath.resolve("veterans_lung_cancer.csv")
+
+  test("Smoke test for scala flink example") {
+    val env = ExecutionEnvironment.createLocalEnvironment(1)
+    val tuple2 = DistTrainWithFlinkExample.runPrediction(env, data, 70)
+    val results = tuple2.f1.collect()
+    results should have size 41
+    forEvery(results)(item => item should have size 1)
+  }
+}
DistTrainWithFlinkSuite.scala (new file)
@@ -0,0 +1,37 @@
+/*
+  Copyright (c) 2014-2023 by Contributors
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+package ml.dmlc.xgboost4j.scala.example.flink
+
+import org.apache.flink.api.java.ExecutionEnvironment
+import org.scalatest.Inspectors._
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers._
+
+import java.nio.file.Paths
+import scala.jdk.CollectionConverters._
+
+class DistTrainWithFlinkSuite extends AnyFunSuite {
+  private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
+  private val data = parentPath.resolve("veterans_lung_cancer.csv")
+
+  test("Smoke test for scala flink example") {
+    implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1)
+    val (_, result) = DistTrainWithFlink.runPrediction(data, 70)
+    val results = result.collect().asScala
+    results should have size 41
+    forEvery(results)(item => item should have size 1)
+  }
+}
jvm-packages/xgboost4j-flink/pom.xml
@@ -8,8 +8,11 @@
     <artifactId>xgboost-jvm_2.12</artifactId>
     <version>2.0.0-SNAPSHOT</version>
   </parent>
-  <artifactId>xgboost4j-flink_2.12</artifactId>
+  <artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
   <version>2.0.0-SNAPSHOT</version>
+  <properties>
+    <flink-ml.version>2.2.0</flink-ml.version>
+  </properties>
   <build>
     <plugins>
       <plugin>
@@ -26,32 +29,22 @@
     <dependency>
       <groupId>ml.dmlc</groupId>
      <artifactId>xgboost4j_${scala.binary.version}</artifactId>
-      <version>2.0.0-SNAPSHOT</version>
+      <version>${project.version}</version>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.commons</groupId>
-      <artifactId>commons-lang3</artifactId>
-      <version>3.12.0</version>
     </dependency>
     <dependency>
       <groupId>org.apache.flink</groupId>
-      <artifactId>flink-scala_${scala.binary.version}</artifactId>
+      <artifactId>flink-clients</artifactId>
       <version>${flink.version}</version>
     </dependency>
     <dependency>
       <groupId>org.apache.flink</groupId>
-      <artifactId>flink-clients_${scala.binary.version}</artifactId>
-      <version>${flink.version}</version>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.flink</groupId>
-      <artifactId>flink-ml_${scala.binary.version}</artifactId>
-      <version>${flink.version}</version>
+      <artifactId>flink-ml-servable-core</artifactId>
+      <version>${flink-ml.version}</version>
     </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-common</artifactId>
-      <version>3.3.5</version>
+      <version>${hadoop.version}</version>
     </dependency>
   </dependencies>
XGBoost.java (new file)
@@ -0,0 +1,187 @@
+/*
+  Copyright (c) 2014-2023 by Contributors
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.java.flink;
+
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.util.Collector;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import ml.dmlc.xgboost4j.LabeledPoint;
+import ml.dmlc.xgboost4j.java.Booster;
+import ml.dmlc.xgboost4j.java.Communicator;
+import ml.dmlc.xgboost4j.java.DMatrix;
+import ml.dmlc.xgboost4j.java.RabitTracker;
+import ml.dmlc.xgboost4j.java.XGBoostError;
+
+
+public class XGBoost {
+  private static final Logger logger = LoggerFactory.getLogger(XGBoost.class);
+
+  private static class MapFunction
+      extends RichMapPartitionFunction<Tuple2<Vector, Double>, XGBoostModel> {
+
+    private final Map<String, Object> params;
+    private final int round;
+    private final Map<String, String> workerEnvs;
+
+    public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
+      this.params = params;
+      this.round = round;
+      this.workerEnvs = workerEnvs;
+    }
+
+    public void mapPartition(java.lang.Iterable<Tuple2<Vector, Double>> it,
+                             Collector<XGBoostModel> collector) throws XGBoostError {
+      workerEnvs.put(
+          "DMLC_TASK_ID",
+          String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask())
+      );
+
+      if (logger.isInfoEnabled()) {
+        logger.info("start with env: {}", workerEnvs.entrySet().stream()
+            .map(e -> String.format("\"%s\": \"%s\"", e.getKey(), e.getValue()))
+            .collect(Collectors.joining(", "))
+        );
+      }
+
+      final Iterator<LabeledPoint> dataIter =
+          StreamSupport
+              .stream(it.spliterator(), false)
+              .map(VectorToPointMapper.INSTANCE)
+              .iterator();
+
+      if (dataIter.hasNext()) {
+        final DMatrix trainMat = new DMatrix(dataIter, null);
+        int numEarlyStoppingRounds =
+            Optional.ofNullable(params.get("numEarlyStoppingRounds"))
+                .map(x -> Integer.parseInt(x.toString()))
+                .orElse(0);
+
+        final Booster booster = trainBooster(trainMat, numEarlyStoppingRounds);
+        collector.collect(new XGBoostModel(booster));
+      } else {
+        logger.warn("Nothing to train with.");
+      }
+    }
+
+    private Booster trainBooster(DMatrix trainMat,
+                                 int numEarlyStoppingRounds) throws XGBoostError {
+      Booster booster;
+      final Map<String, DMatrix> watches =
+          new HashMap<String, DMatrix>() {{ put("train", trainMat); }};
+      try {
+        Communicator.init(workerEnvs);
+        booster = ml.dmlc.xgboost4j.java.XGBoost
+            .train(
+                trainMat,
+                params,
+                round,
+                watches,
+                null,
+                null,
+                null,
+                numEarlyStoppingRounds);
+      } catch (XGBoostError xgbException) {
+        final String identifier = String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask());
+        logger.warn(
+            String.format("XGBooster worker %s has failed due to", identifier),
+            xgbException
+        );
+        throw xgbException;
+      } finally {
+        Communicator.shutdown();
+      }
+      return booster;
+    }
+
+    private static class VectorToPointMapper
+        implements Function<Tuple2<Vector, Double>, LabeledPoint> {
+      public static VectorToPointMapper INSTANCE = new VectorToPointMapper();
+      @Override
+      public LabeledPoint apply(Tuple2<Vector, Double> tuple) {
+        final SparseVector vector = tuple.f0.toSparse();
+        final double[] values = vector.values;
+        final int size = values.length;
+        final float[] array = new float[size];
+        for (int i = 0; i < size; i++) {
+          array[i] = (float) values[i];
+        }
+        return new LabeledPoint(
+            tuple.f1.floatValue(),
+            vector.size(),
+            vector.indices,
+            array);
+      }
+    }
+  }
+
+  /**
+   * Load XGBoost model from path, using Hadoop Filesystem API.
+   *
+   * @param modelPath The path that is accessible by hadoop filesystem API.
+   * @return The loaded model
+   */
+  public static XGBoostModel loadModelFromHadoopFile(final String modelPath) throws Exception {
+    final FileSystem fileSystem = FileSystem.get(new Configuration());
+    final Path f = new Path(modelPath);
+
+    try (FSDataInputStream opened = fileSystem.open(f)) {
+      return new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(opened));
+    }
+  }
+
+  /**
+   * Train an XGBoost model with Flink.
+   *
+   * @param dtrain The training data.
+   * @param params XGBoost parameters.
+   * @param numBoostRound Number of rounds to train.
+   */
+  public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain,
+                                   Map<String, Object> params,
+                                   int numBoostRound) throws Exception {
+    final RabitTracker tracker =
+        new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
+    if (tracker.start(0L)) {
+      return dtrain
+          .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
+          .reduce((x, y) -> x)
+          .collect()
+          .get(0);
+    } else {
+      throw new Error("Tracker cannot be started");
+    }
+  }
+}
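One non-obvious detail of the new train API: early stopping is configured through the same parameter map, via an optional numEarlyStoppingRounds entry that each worker reads back with Integer.parseInt on its toString(). A minimal sketch, with illustrative values and a hypothetical wrapper class:

package ml.dmlc.xgboost4j.java.flink;

import java.util.HashMap;
import java.util.Map;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.Vector;

// Hypothetical helper showing how early stopping rides along in the params map.
public class TrainWithEarlyStopping {
  public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> trainData) throws Exception {
    final Map<String, Object> params = new HashMap<>();
    params.put("eta", 0.1);
    params.put("objective", "binary:logistic");
    // Read by MapFunction via params.get("numEarlyStoppingRounds");
    // any value whose toString() parses as an int works.
    params.put("numEarlyStoppingRounds", 5);
    return XGBoost.train(trainData, params, 100);
  }
}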
XGBoostModel.java (new file)
@@ -0,0 +1,136 @@
+/*
+  Copyright (c) 2014-2023 by Contributors
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.java.flink;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.stream.StreamSupport;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.util.Collector;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+
+import ml.dmlc.xgboost4j.LabeledPoint;
+import ml.dmlc.xgboost4j.java.Booster;
+import ml.dmlc.xgboost4j.java.DMatrix;
+import ml.dmlc.xgboost4j.java.XGBoostError;
+
+
+public class XGBoostModel implements Serializable {
+  private static final org.slf4j.Logger logger =
+      org.slf4j.LoggerFactory.getLogger(XGBoostModel.class);
+
+  private final Booster booster;
+  private final PredictorFunction predictorFunction;
+
+
+  public XGBoostModel(Booster booster) {
+    this.booster = booster;
+    this.predictorFunction = new PredictorFunction(booster);
+  }
+
+  /**
+   * Save the model as a Hadoop filesystem file.
+   *
+   * @param modelPath The model path as in Hadoop path.
+   */
+  public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError {
+    booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)));
+  }
+
+  public byte[] toByteArray(String format) throws XGBoostError {
+    return booster.toByteArray(format);
+  }
+
+  /**
+   * Save the model as a Hadoop filesystem file.
+   *
+   * @param modelPath The model path as in Hadoop path.
+   * @param format The model format (ubj, json, deprecated)
+   * @throws XGBoostError internal error
+   * @throws IOException save error
+   */
+  public void saveModelAsHadoopFile(String modelPath, String format)
+      throws IOException, XGBoostError {
+    booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)), format);
+  }
+
+  /**
+   * predict with the given DMatrix
+   *
+   * @param testSet the local test set represented as DMatrix
+   * @return prediction result
+   */
+  public float[][] predict(DMatrix testSet) throws XGBoostError {
+    return booster.predict(testSet, true, 0);
+  }
+
+  /**
+   * Predict given vector dataset.
+   *
+   * @param data The dataset to be predicted.
+   * @return The prediction result.
+   */
+  public DataSet<Float[]> predict(DataSet<Vector> data) {
+    return data.mapPartition(predictorFunction);
+  }
+
+
+  private static class PredictorFunction implements MapPartitionFunction<Vector, Float[]> {
+
+    private final Booster booster;
+
+    public PredictorFunction(Booster booster) {
+      this.booster = booster;
+    }
+
+    @Override
+    public void mapPartition(Iterable<Vector> it, Collector<Float[]> out) throws Exception {
+      final Iterator<LabeledPoint> dataIter =
+          StreamSupport.stream(it.spliterator(), false)
+              .map(Vector::toSparse)
+              .map(PredictorFunction::fromVector)
+              .iterator();
+
+      if (dataIter.hasNext()) {
+        final DMatrix data = new DMatrix(dataIter, null);
+        float[][] predictions = booster.predict(data, true, 2);
+        Arrays.stream(predictions).map(ArrayUtils::toObject).forEach(out::collect);
+      } else {
+        logger.debug("Empty partition");
+      }
+    }
+
+    private static LabeledPoint fromVector(SparseVector vector) {
+      final int[] index = vector.indices;
+      final double[] value = vector.values;
+      int size = value.length;
+      final float[] values = new float[size];
+      for (int i = 0; i < size; i++) {
+        values[i] = (float) value[i];
+      }
+      return new LabeledPoint(0.0f, vector.size(), index, values);
+    }
+  }
+}
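XGBoostModel keeps the Hadoop-filesystem persistence the old Scala API had, now with an explicit format overload, and XGBoost.loadModelFromHadoopFile completes the round-trip. A sketch, with an illustrative local file URI and a hypothetical helper class:

package ml.dmlc.xgboost4j.java.flink;

// Hypothetical save/reload round-trip; the URI is illustrative and any
// Hadoop-accessible path should work.
public class ModelRoundTrip {
  public static XGBoostModel saveAndReload(XGBoostModel model) throws Exception {
    // "ubj" is one of the formats the overload documents (ubj, json, deprecated).
    model.saveModelAsHadoopFile("file:///tmp/xgboost-flink-model.ubj", "ubj");
    return XGBoost.loadModelFromHadoopFile("file:///tmp/xgboost-flink-model.ubj");
  }
}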
XGBoost.scala (deleted)
@@ -1,99 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.flink
-
-import scala.collection.JavaConverters.asScalaIteratorConverter
-
-import ml.dmlc.xgboost4j.LabeledPoint
-import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
-import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
-
-import org.apache.commons.logging.LogFactory
-import org.apache.flink.api.common.functions.RichMapPartitionFunction
-import org.apache.flink.api.scala.{DataSet, _}
-import org.apache.flink.ml.common.LabeledVector
-import org.apache.flink.util.Collector
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-object XGBoost {
-  /**
-   * Helper map function to start the job.
-   *
-   * @param workerEnvs
-   */
-  private class MapFunction(paramMap: Map[String, Any],
-                            round: Int,
-                            workerEnvs: java.util.Map[String, String])
-    extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
-    val logger = LogFactory.getLog(this.getClass)
-
-    def mapPartition(it: java.lang.Iterable[LabeledVector],
-                     collector: Collector[XGBoostModel]): Unit = {
-      workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
-      logger.info("start with env" + workerEnvs.toString)
-      Communicator.init(workerEnvs)
-      val mapper = (x: LabeledVector) => {
-        val (index, value) = x.vector.toSeq.unzip
-        LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
-      }
-      val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
-      val trainMat = new DMatrix(dataIter, null)
-      val watches = List("train" -> trainMat).toMap
-      val round = 2
-      val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds")
-        .map(_.toString.toInt).getOrElse(0)
-      val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
-        earlyStoppingRound = numEarlyStoppingRounds)
-      Communicator.shutdown()
-      collector.collect(new XGBoostModel(booster))
-    }
-  }
-
-  val logger = LogFactory.getLog(this.getClass)
-
-  /**
-   * Load XGBoost model from path, using Hadoop Filesystem API.
-   *
-   * @param modelPath The path that is accessible by hadoop filesystem API.
-   * @return The loaded model
-   */
-  def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
-    new XGBoostModel(
-      XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath))))
-  }
-
-  /**
-   * Train a xgboost model with link.
-   *
-   * @param dtrain The training data.
-   * @param params The parameters to XGBoost.
-   * @param round Number of rounds to train.
-   */
-  def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
-      XGBoostModel = {
-    val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
-    if (tracker.start(0L)) {
-      dtrain
-        .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
-        .reduce((x, y) => x).collect().head
-    } else {
-      throw new Error("Tracker cannot be started")
-      null
-    }
-  }
-}
XGBoostModel.scala (deleted)
@@ -1,67 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.flink
-
-import ml.dmlc.xgboost4j.LabeledPoint
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
-
-import org.apache.flink.api.scala.{DataSet, _}
-import org.apache.flink.ml.math.Vector
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-class XGBoostModel (booster: Booster) extends Serializable {
-  /**
-   * Save the model as a Hadoop filesystem file.
-   *
-   * @param modelPath The model path as in Hadoop path.
-   */
-  def saveModelAsHadoopFile(modelPath: String): Unit = {
-    booster.saveModel(FileSystem
-      .get(new Configuration)
-      .create(new Path(modelPath)))
-  }
-
-  /**
-   * predict with the given DMatrix
-   * @param testSet the local test set represented as DMatrix
-   * @return prediction result
-   */
-  def predict(testSet: DMatrix): Array[Array[Float]] = {
-    booster.predict(testSet, true, 0)
-  }
-
-  /**
-   * Predict given vector dataset.
-   *
-   * @param data The dataset to be predicted.
-   * @return The prediction result.
-   */
-  def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
-    val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
-      (it: Iterator[Vector]) => {
-        val mapper = (x: Vector) => {
-          val (index, value) = x.toSeq.unzip
-          LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
-        }
-        val dataIter = for (x <- it) yield mapper(x)
-        val dmat = new DMatrix(dataIter, null)
-        this.booster.predict(dmat)
-      }
-    data.mapPartition(predictMap)
-  }
-}