merge 23Mar01
@@ -33,16 +33,16 @@
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
-<flink.version>1.8.3</flink.version>
-<spark.version>3.1.1</spark.version>
-<scala.version>2.12.8</scala.version>
+<flink.version>1.17.0</flink.version>
+<spark.version>3.4.0</spark.version>
+<scala.version>2.12.17</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<hadoop.version>3.3.5</hadoop.version>
<maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
<log.capi.invocation>OFF</log.capi.invocation>
<use.cuda>OFF</use.cuda>
-<cudf.version>22.12.0</cudf.version>
-<spark.rapids.version>22.12.0</spark.rapids.version>
+<cudf.version>23.04.0</cudf.version>
+<spark.rapids.version>23.04.0</spark.rapids.version>
<cudf.classifier>cuda11</cudf.classifier>
</properties>
<repositories>
@@ -374,7 +374,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
-<version>3.2.1</version>
+<version>3.2.2</version>
<configuration>
<configLocation>checkstyle.xml</configLocation>
<failOnViolation>true</failOnViolation>
@@ -450,7 +450,7 @@
<plugins>
<plugin>
<artifactId>maven-project-info-reports-plugin</artifactId>
-<version>3.4.2</version>
+<version>3.4.3</version>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
@@ -469,7 +469,7 @@
<dependency>
<groupId>com.esotericsoftware</groupId>
<artifactId>kryo</artifactId>
-<version>5.4.0</version>
+<version>5.5.0</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
@@ -477,11 +477,6 @@
<version>${scala.version}</version>
<scope>provided</scope>
</dependency>
-<dependency>
-<groupId>org.scala-lang</groupId>
-<artifactId>scala-reflect</artifactId>
-<version>${scala.version}</version>
-</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
@@ -495,13 +490,13 @@
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
-<version>3.0.8</version>
+<version>3.2.15</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalactic</groupId>
<artifactId>scalactic_${scala.binary.version}</artifactId>
-<version>3.0.8</version>
+<version>3.2.15</version>
<scope>test</scope>
</dependency>
</dependencies>

@@ -26,7 +26,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
-<version>2.0.0-SNAPSHOT</version>
+<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
@@ -37,12 +37,7 @@
-<dependency>
-<groupId>ml.dmlc</groupId>
-<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
-<version>2.0.0-SNAPSHOT</version>
-</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
-<version>3.12.0</version>
+<version>${project.version}</version>
</dependency>
</dependencies>
</project>

@@ -1,5 +1,5 @@
/*
-Copyright (c) 2014-2021 by Contributors
+Copyright (c) 2014-2023 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -62,8 +62,8 @@ public class BasicWalkThrough {

public static void main(String[] args) throws IOException, XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
-DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
-DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
+DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");

HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
@@ -112,7 +112,8 @@ public class BasicWalkThrough {

System.out.println("start build dmatrix from csr sparse data ...");
//build dmatrix from CSR Sparse Matrix
-DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
+DataLoader.CSRSparseData spData =
+DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");

DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
DMatrix.SparseType.CSR, 127);
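The change repeated across these example files is the explicit "?format=libsvm" query argument on every text-file URI, so the DMatrix loader is told the text format instead of being left to infer it from the file contents. A minimal standalone sketch of the convention (the class name LoadLibSvm is illustrative, not part of this commit):

import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class LoadLibSvm {
  public static void main(String[] args) throws XGBoostError {
    // The query argument names the text format explicitly; without it,
    // newer XGBoost releases no longer guess the format from the content.
    DMatrix train = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
    System.out.println("rows: " + train.rowNum());
  }
}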

@@ -32,8 +32,8 @@ public class BoostFromPrediction {
System.out.println("start running example to start from a initial prediction");

// load file from text file, also binary buffer generated by xgboost4j
-DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
-DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
+DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");

//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

@@ -30,7 +30,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
public class CrossValidation {
public static void main(String[] args) throws IOException, XGBoostError {
//load train mat
-DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");

//set params
HashMap<String, Object> params = new HashMap<String, Object>();

@@ -139,9 +139,9 @@ public class CustomObjective {

public static void main(String[] args) throws XGBoostError {
//load train mat (svmlight format)
-DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
//load valid mat (svmlight format)
-DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");

HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);

@@ -29,9 +29,9 @@ import ml.dmlc.xgboost4j.java.example.util.DataLoader;
public class EarlyStopping {
public static void main(String[] args) throws IOException, XGBoostError {
DataLoader.CSRSparseData trainCSR =
-DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
+DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
DataLoader.CSRSparseData testCSR =
-DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test");
+DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test?format=libsvm");

Map<String, Object> paramMap = new HashMap<String, Object>() {
{

@@ -32,8 +32,8 @@ public class ExternalMemory {
//this is the only difference, add a # followed by a cache prefix name
//several cache file with the prefix will be generated
//currently only support convert from libsvm file
-DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
-DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
+DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache");
+DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache");

//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
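As the comments in that hunk note, external memory is enabled by a "#" cache-prefix suffix; with the new URI scheme the format query comes first and the cache fragment second. A small sketch under the same assumptions as the illustrative example above:

import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class ExternalMemorySketch {
  public static void main(String[] args) throws XGBoostError {
    // format query first, cache-prefix fragment second; XGBoost writes its
    // external-memory cache files next to the given prefix (dtrain.cache.*).
    DMatrix train =
        new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache");
    System.out.println("rows: " + train.rowNum());
  }
}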

@@ -32,8 +32,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
public class GeneralizedLinearModel {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
-DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
-DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
+DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");

//specify parameters
//change booster to gblinear, so that we are fitting a linear model

@@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
public class PredictFirstNtree {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
-DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
-DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
+DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");

//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

@@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
public class PredictLeafIndices {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
-DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
-DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
+DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");

//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

@@ -0,0 +1,107 @@
/*
  Copyright (c) 2014-2021 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example.flink;

import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;

import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple13;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;

import ml.dmlc.xgboost4j.java.flink.XGBoost;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;


public class DistTrainWithFlinkExample {

  static Tuple2<XGBoostModel, DataSet<Float[]>> runPrediction(
      ExecutionEnvironment env,
      java.nio.file.Path trainPath,
      int percentage) throws Exception {
    // reading data
    final DataSet<Tuple2<Long, Tuple2<Vector, Double>>> data =
        DataSetUtils.zipWithIndex(parseCsv(env, trainPath));
    final long size = data.count();
    final long trainCount = Math.round(size * 0.01 * percentage);
    final DataSet<Tuple2<Vector, Double>> trainData =
        data
          .filter(item -> item.f0 < trainCount)
          .map(t -> t.f1)
          .returns(TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>(){}));
    final DataSet<Vector> testData =
        data
          .filter(tuple -> tuple.f0 >= trainCount)
          .map(t -> t.f1.f0)
          .returns(TypeInformation.of(new TypeHint<Vector>(){}));

    // define parameters
    HashMap<String, Object> paramMap = new HashMap<String, Object>(3);
    paramMap.put("eta", 0.1);
    paramMap.put("max_depth", 2);
    paramMap.put("objective", "binary:logistic");

    // number of iterations
    final int round = 2;
    // train the model
    XGBoostModel model = XGBoost.train(trainData, paramMap, round);
    DataSet<Float[]> predTest = model.predict(testData);
    return new Tuple2<XGBoostModel, DataSet<Float[]>>(model, predTest);
  }

  private static MapOperator<Tuple13<Double, String, Double, Double, Double, Integer, Integer,
      Integer, Integer, Integer, Integer, Integer, Integer>,
      Tuple2<Vector, Double>> parseCsv(ExecutionEnvironment env, Path trainPath) {
    return env.readCsvFile(trainPath.toString())
      .ignoreFirstLine()
      .types(Double.class, String.class, Double.class, Double.class, Double.class,
          Integer.class, Integer.class, Integer.class, Integer.class, Integer.class,
          Integer.class, Integer.class, Integer.class)
      .map(DistTrainWithFlinkExample::mapFunction);
  }

  private static Tuple2<Vector, Double> mapFunction(Tuple13<Double, String, Double, Double, Double,
      Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer> tuple) {
    final DenseVector dense = Vectors.dense(tuple.f2, tuple.f3, tuple.f4, tuple.f5, tuple.f6,
        tuple.f7, tuple.f8, tuple.f9, tuple.f10, tuple.f11, tuple.f12);
    if (tuple.f1.contains("inf")) {
      return new Tuple2<Vector, Double>(dense, 1.0);
    } else {
      return new Tuple2<Vector, Double>(dense, 0.0);
    }
  }

  public static void main(String[] args) throws Exception {
    final java.nio.file.Path parentPath = java.nio.file.Paths.get(Arrays.stream(args)
        .findFirst().orElse("."));
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    Tuple2<XGBoostModel, DataSet<Float[]>> tuple2 = runPrediction(
        env, parentPath.resolve("veterans_lung_cancer.csv"), 70
    );
    List<Float[]> list = tuple2.f1.collect();
    System.out.println(list.size());
  }
}
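A sketch of how this new example might be driven locally; runPrediction is package-private, so the caller sits in the same package, and the class name RunFlinkExampleLocally is illustrative rather than part of the commit:

package ml.dmlc.xgboost4j.java.example.flink;

import java.util.List;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;

import ml.dmlc.xgboost4j.java.flink.XGBoostModel;

public class RunFlinkExampleLocally {
  public static void main(String[] args) throws Exception {
    // single-slot local environment, mirroring the smoke tests added later in this commit
    ExecutionEnvironment env = ExecutionEnvironment.createLocalEnvironment(1);
    // train on 70% of the CSV rows and predict on the remaining 30%
    Tuple2<XGBoostModel, DataSet<Float[]>> result = DistTrainWithFlinkExample.runPrediction(
        env, java.nio.file.Paths.get("../../demo/data/veterans_lung_cancer.csv"), 70);
    List<Float[]> predictions = result.f1.collect();
    System.out.println("predicted " + predictions.size() + " rows");
  }
}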

@@ -1,5 +1,5 @@
/*
-Copyright (c) 2014 by Contributors
+Copyright (c) 2014-2023 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -36,8 +36,8 @@ object BasicWalkThrough {
}

def main(args: Array[String]): Unit = {
-val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
-val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
+val trainMax = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+val testMax = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")

val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
@@ -76,7 +76,7 @@ object BasicWalkThrough {

// build dmatrix from CSR Sparse Matrix
println("start build dmatrix from csr sparse data ...")
-val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train")
+val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm")
val trainMax2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
JDMatrix.SparseType.CSR)
trainMax2.setLabel(spData.labels)

@@ -24,8 +24,8 @@ object BoostFromPrediction {
def main(args: Array[String]): Unit = {
println("start running example to start from a initial prediction")

-val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
-val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")

val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}

object CrossValidation {
def main(args: Array[String]): Unit = {
-val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train")
+val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")

// set params
val params = new mutable.HashMap[String, Any]

@@ -138,8 +138,8 @@ object CustomObjective {
}

def main(args: Array[String]): Unit = {
-val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
-val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
params += "max_depth" -> 2

@@ -25,8 +25,8 @@ object ExternalMemory {
// this is the only difference, add a # followed by a cache prefix name
// several cache file with the prefix will be generated
// currently only support convert from libsvm file
-val trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache")
-val testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache")
+val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache")
+val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache")

val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

@@ -27,8 +27,8 @@ import ml.dmlc.xgboost4j.scala.example.util.CustomEval
*/
object GeneralizedLinearModel {
def main(args: Array[String]): Unit = {
-val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
-val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")

// specify parameters
// change booster to gblinear, so that we are fitting a linear model

@@ -23,8 +23,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object PredictFirstNTree {

def main(args: Array[String]): Unit = {
-val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
-val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")

val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

@@ -25,8 +25,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object PredictLeafIndices {

def main(args: Array[String]): Unit = {
-val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
-val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")

val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

@@ -1,5 +1,5 @@
/*
-Copyright (c) 2014 by Contributors
+Copyright (c) 2014 - 2023 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,27 +15,84 @@
*/
package ml.dmlc.xgboost4j.scala.example.flink

-import ml.dmlc.xgboost4j.scala.flink.XGBoost
-import org.apache.flink.api.scala.{ExecutionEnvironment, _}
-import org.apache.flink.ml.MLUtils
+import java.lang.{Double => JDouble, Long => JLong}
+import java.nio.file.{Path, Paths}
+import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
+import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
+import org.apache.flink.ml.linalg.{Vector, Vectors}
+import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
+import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
+import org.apache.flink.api.java.utils.DataSetUtils


object DistTrainWithFlink {
-  def main(args: Array[String]) {
-    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
-    // read trainining data
-    val trainData =
-      MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
-    val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test")
-    // define parameters
-    val paramMap = List(
-      "eta" -> 0.1,
-      "max_depth" -> 2,
-      "objective" -> "binary:logistic").toMap
+  import scala.jdk.CollectionConverters._
+  private val rowTypeHint = TypeInformation.of(new TypeHint[Tuple2[Vector, JDouble]]{})
+  private val testDataTypeHint = TypeInformation.of(classOf[Vector])
+
+  private[flink] def parseCsv(trainPath: Path)(implicit env: ExecutionEnvironment):
+      DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = {
+    DataSetUtils.zipWithIndex(
+      env
+        .readCsvFile(trainPath.toString)
+        .ignoreFirstLine
+        .types(
+          classOf[Double], classOf[String], classOf[Double], classOf[Double], classOf[Double],
+          classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer],
+          classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer]
+        )
+        .map((row: Tuple13[Double, String, Double, Double, Double,
+          Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer]) => {
+          val dense = Vectors.dense(row.f2, row.f3, row.f4,
+            row.f5.toDouble, row.f6.toDouble, row.f7.toDouble, row.f8.toDouble,
+            row.f9.toDouble, row.f10.toDouble, row.f11.toDouble, row.f12.toDouble)
+          val label = if (row.f1.contains("inf")) {
+            JDouble.valueOf(1.0)
+          } else {
+            JDouble.valueOf(0.0)
+          }
+          new Tuple2[Vector, JDouble](dense, label)
+        })
+        .returns(rowTypeHint)
+    )
+  }
+
+  private[flink] def runPrediction(trainPath: Path, percentage: Int)
+      (implicit env: ExecutionEnvironment):
+      (XGBoostModel, DataSet[Array[Float]]) = {
+    // read training data
+    val data: DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = parseCsv(trainPath)
+    val trainSize = Math.round(0.01 * percentage * data.count())
+    val trainData: DataSet[Tuple2[Vector, JDouble]] =
+      data.filter(d => d.f0 < trainSize).map(_.f1).returns(rowTypeHint)
+
+    val testData: DataSet[Vector] =
+      data
+        .filter(d => d.f0 >= trainSize)
+        .map(_.f1.f0)
+        .returns(testDataTypeHint)
+
+    val paramMap = mapAsJavaMap(Map(
+      ("eta", "0.1".asInstanceOf[AnyRef]),
+      ("max_depth", "2"),
+      ("objective", "binary:logistic"),
+      ("verbosity", "1")
+    ))

// number of iterations
val round = 2
// train the model
val model = XGBoost.train(trainData, paramMap, round)
-val predTest = model.predict(testData.map{x => x.vector})
-model.saveModelAsHadoopFile("file:///path/to/xgboost.model")
+val result = model.predict(testData).map(prediction => prediction.map(Float.unbox))
+(model, result)
}

+  def main(args: Array[String]): Unit = {
+    implicit val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val parentPath = Paths.get(args.headOption.getOrElse("."))
+    val (_, predTest) = runPrediction(parentPath.resolve("veterans_lung_cancer.csv"), 70)
+    val list = predTest.collect().asScala
+    println(list.length)
+  }
}

@@ -0,0 +1,36 @@
/*
  Copyright (c) 2014-2023 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example.flink;

import org.apache.flink.api.java.ExecutionEnvironment
import org.scalatest.Inspectors._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._

import java.nio.file.Paths

class DistTrainWithFlinkExampleTest extends AnyFunSuite {
  private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
  private val data = parentPath.resolve("veterans_lung_cancer.csv")

  test("Smoke test for scala flink example") {
    val env = ExecutionEnvironment.createLocalEnvironment(1)
    val tuple2 = DistTrainWithFlinkExample.runPrediction(env, data, 70)
    val results = tuple2.f1.collect()
    results should have size 41
    forEvery(results)(item => item should have size 1)
  }
}

@@ -0,0 +1,37 @@
/*
  Copyright (c) 2014-2023 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
 */
package ml.dmlc.xgboost4j.scala.example.flink

import org.apache.flink.api.java.ExecutionEnvironment
import org.scalatest.Inspectors._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._

import java.nio.file.Paths
import scala.jdk.CollectionConverters._

class DistTrainWithFlinkSuite extends AnyFunSuite {
  private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
  private val data = parentPath.resolve("veterans_lung_cancer.csv")

  test("Smoke test for scala flink example") {
    implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1)
    val (_, result) = DistTrainWithFlink.runPrediction(data, 70)
    val results = result.collect().asScala
    results should have size 41
    forEvery(results)(item => item should have size 1)
  }
}

@@ -8,8 +8,11 @@
<artifactId>xgboost-jvm_2.12</artifactId>
<version>2.0.0-SNAPSHOT</version>
</parent>
-<artifactId>xgboost4j-flink_2.12</artifactId>
+<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
+<properties>
+<flink-ml.version>2.2.0</flink-ml.version>
+</properties>
<build>
<plugins>
<plugin>
@@ -26,32 +29,22 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
-<version>3.12.0</version>
+<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
-<artifactId>flink-scala_${scala.binary.version}</artifactId>
+<artifactId>flink-clients</artifactId>
<version>${flink.version}</version>
</dependency>
-<dependency>
-<groupId>org.apache.flink</groupId>
-<artifactId>flink-clients_${scala.binary.version}</artifactId>
-<version>${flink.version}</version>
-</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
-<artifactId>flink-ml_${scala.binary.version}</artifactId>
-<version>${flink.version}</version>
+<artifactId>flink-ml-servable-core</artifactId>
+<version>${flink-ml.version}</version>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
-<version>3.3.5</version>
+<version>${hadoop.version}</version>
</dependency>
</dependencies>

@@ -0,0 +1,187 @@
/*
  Copyright (c) 2014-2023 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
 */

package ml.dmlc.xgboost4j.java.flink;


import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.RabitTracker;
import ml.dmlc.xgboost4j.java.XGBoostError;


public class XGBoost {
  private static final Logger logger = LoggerFactory.getLogger(XGBoost.class);

  private static class MapFunction
      extends RichMapPartitionFunction<Tuple2<Vector, Double>, XGBoostModel> {

    private final Map<String, Object> params;
    private final int round;
    private final Map<String, String> workerEnvs;

    public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
      this.params = params;
      this.round = round;
      this.workerEnvs = workerEnvs;
    }

    public void mapPartition(java.lang.Iterable<Tuple2<Vector, Double>> it,
                             Collector<XGBoostModel> collector) throws XGBoostError {
      workerEnvs.put(
          "DMLC_TASK_ID",
          String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask())
      );

      if (logger.isInfoEnabled()) {
        logger.info("start with env: {}", workerEnvs.entrySet().stream()
            .map(e -> String.format("\"%s\": \"%s\"", e.getKey(), e.getValue()))
            .collect(Collectors.joining(", "))
        );
      }

      final Iterator<LabeledPoint> dataIter =
          StreamSupport
            .stream(it.spliterator(), false)
            .map(VectorToPointMapper.INSTANCE)
            .iterator();

      if (dataIter.hasNext()) {
        final DMatrix trainMat = new DMatrix(dataIter, null);
        int numEarlyStoppingRounds =
            Optional.ofNullable(params.get("numEarlyStoppingRounds"))
              .map(x -> Integer.parseInt(x.toString()))
              .orElse(0);

        final Booster booster = trainBooster(trainMat, numEarlyStoppingRounds);
        collector.collect(new XGBoostModel(booster));
      } else {
        logger.warn("Nothing to train with.");
      }
    }

    private Booster trainBooster(DMatrix trainMat,
                                 int numEarlyStoppingRounds) throws XGBoostError {
      Booster booster;
      final Map<String, DMatrix> watches =
          new HashMap<String, DMatrix>() {{ put("train", trainMat); }};
      try {
        Communicator.init(workerEnvs);
        booster = ml.dmlc.xgboost4j.java.XGBoost
            .train(
              trainMat,
              params,
              round,
              watches,
              null,
              null,
              null,
              numEarlyStoppingRounds);
      } catch (XGBoostError xgbException) {
        final String identifier = String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask());
        logger.warn(
            String.format("XGBooster worker %s has failed due to", identifier),
            xgbException
        );
        throw xgbException;
      } finally {
        Communicator.shutdown();
      }
      return booster;
    }

    private static class VectorToPointMapper
        implements Function<Tuple2<Vector, Double>, LabeledPoint> {
      public static VectorToPointMapper INSTANCE = new VectorToPointMapper();
      @Override
      public LabeledPoint apply(Tuple2<Vector, Double> tuple) {
        final SparseVector vector = tuple.f0.toSparse();
        final double[] values = vector.values;
        final int size = values.length;
        final float[] array = new float[size];
        for (int i = 0; i < size; i++) {
          array[i] = (float) values[i];
        }
        return new LabeledPoint(
            tuple.f1.floatValue(),
            vector.size(),
            vector.indices,
            array);
      }
    }
  }

  /**
   * Load XGBoost model from path, using Hadoop Filesystem API.
   *
   * @param modelPath The path that is accessible by hadoop filesystem API.
   * @return The loaded model
   */
  public static XGBoostModel loadModelFromHadoopFile(final String modelPath) throws Exception {
    final FileSystem fileSystem = FileSystem.get(new Configuration());
    final Path f = new Path(modelPath);

    try (FSDataInputStream opened = fileSystem.open(f)) {
      return new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(opened));
    }
  }

  /**
   * Train an XGBoost model with Flink.
   *
   * @param dtrain The training data.
   * @param params XGBoost parameters.
   * @param numBoostRound Number of rounds to train.
   */
  public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain,
                                   Map<String, Object> params,
                                   int numBoostRound) throws Exception {
    final RabitTracker tracker =
        new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
    if (tracker.start(0L)) {
      return dtrain
          .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
          .reduce((x, y) -> x)
          .collect()
          .get(0);
    } else {
      throw new Error("Tracker cannot be started");
    }
  }
}
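The class above replaces the Scala mapPartition-based trainer deleted further down: each partition of the input DataSet becomes one Rabit worker, and reduce((x, y) -> x) keeps a single trained model. A hedged usage sketch (the toy DataSet and the /tmp output path are assumptions, not part of the commit):

import java.util.Arrays;
import java.util.HashMap;

import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;

import ml.dmlc.xgboost4j.java.flink.XGBoost;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;

public class TrainWithFlinkSketch {
  public static void main(String[] args) throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    // two toy rows; a real job would build this DataSet from readCsvFile or similar,
    // and the explicit TypeInformation avoids relying on Flink's type extraction
    DataSet<Tuple2<Vector, Double>> train = env.fromCollection(
        Arrays.asList(
            new Tuple2<Vector, Double>(Vectors.dense(1.0, 0.0), 1.0),
            new Tuple2<Vector, Double>(Vectors.dense(0.0, 1.0), 0.0)),
        TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>(){}));
    HashMap<String, Object> params = new HashMap<String, Object>();
    params.put("eta", 0.1);
    params.put("objective", "binary:logistic");
    // one Rabit worker per input partition joins the tracker and trains two rounds
    XGBoostModel model = XGBoost.train(train, params, 2);
    model.saveModelAsHadoopFile("file:///tmp/xgboost-flink.ubj", "ubj");
  }
}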

@@ -0,0 +1,136 @@
/*
  Copyright (c) 2014-2023 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
 */

package ml.dmlc.xgboost4j.java.flink;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.stream.StreamSupport;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;


public class XGBoostModel implements Serializable {
  private static final org.slf4j.Logger logger =
      org.slf4j.LoggerFactory.getLogger(XGBoostModel.class);

  private final Booster booster;
  private final PredictorFunction predictorFunction;


  public XGBoostModel(Booster booster) {
    this.booster = booster;
    this.predictorFunction = new PredictorFunction(booster);
  }

  /**
   * Save the model as a Hadoop filesystem file.
   *
   * @param modelPath The model path as in Hadoop path.
   */
  public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError {
    booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)));
  }

  public byte[] toByteArray(String format) throws XGBoostError {
    return booster.toByteArray(format);
  }

  /**
   * Save the model as a Hadoop filesystem file.
   *
   * @param modelPath The model path as in Hadoop path.
   * @param format The model format (ubj, json, deprecated)
   * @throws XGBoostError internal error
   * @throws IOException save error
   */
  public void saveModelAsHadoopFile(String modelPath, String format)
      throws IOException, XGBoostError {
    booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)), format);
  }

  /**
   * predict with the given DMatrix
   *
   * @param testSet the local test set represented as DMatrix
   * @return prediction result
   */
  public float[][] predict(DMatrix testSet) throws XGBoostError {
    return booster.predict(testSet, true, 0);
  }

  /**
   * Predict given vector dataset.
   *
   * @param data The dataset to be predicted.
   * @return The prediction result.
   */
  public DataSet<Float[]> predict(DataSet<Vector> data) {
    return data.mapPartition(predictorFunction);
  }


  private static class PredictorFunction implements MapPartitionFunction<Vector, Float[]> {

    private final Booster booster;

    public PredictorFunction(Booster booster) {
      this.booster = booster;
    }

    @Override
    public void mapPartition(Iterable<Vector> it, Collector<Float[]> out) throws Exception {
      final Iterator<LabeledPoint> dataIter =
          StreamSupport.stream(it.spliterator(), false)
            .map(Vector::toSparse)
            .map(PredictorFunction::fromVector)
            .iterator();

      if (dataIter.hasNext()) {
        final DMatrix data = new DMatrix(dataIter, null);
        float[][] predictions = booster.predict(data, true, 2);
        Arrays.stream(predictions).map(ArrayUtils::toObject).forEach(out::collect);
      } else {
        logger.debug("Empty partition");
      }
    }

    private static LabeledPoint fromVector(SparseVector vector) {
      final int[] index = vector.indices;
      final double[] value = vector.values;
      int size = value.length;
      final float[] values = new float[size];
      for (int i = 0; i < size; i++) {
        values[i] = (float) value[i];
      }
      return new LabeledPoint(0.0f, vector.size(), index, values);
    }
  }
}
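A companion sketch for the model side: reloading a saved model and scoring a DataSet of vectors (the path and toy rows are assumptions; predict converts each partition into one DMatrix, as shown above):

import java.util.Arrays;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;

import ml.dmlc.xgboost4j.java.flink.XGBoost;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;

public class PredictWithFlinkSketch {
  public static void main(String[] args) throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    // reload a model previously written with saveModelAsHadoopFile(...)
    XGBoostModel model = XGBoost.loadModelFromHadoopFile("file:///tmp/xgboost-flink.ubj");
    DataSet<Vector> rows = env.fromCollection(
        Arrays.asList((Vector) Vectors.dense(1.0, 0.0), (Vector) Vectors.dense(0.0, 1.0)),
        TypeInformation.of(Vector.class));
    // each partition becomes one DMatrix and is scored by the wrapped Booster
    model.predict(rows).print();
  }
}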

@@ -1,99 +0,0 @@
/*
  Copyright (c) 2014 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.flink

import scala.collection.JavaConverters.asScalaIteratorConverter

import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}

import org.apache.commons.logging.LogFactory
import org.apache.flink.api.common.functions.RichMapPartitionFunction
import org.apache.flink.api.scala.{DataSet, _}
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.util.Collector
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

object XGBoost {
  /**
   * Helper map function to start the job.
   *
   * @param workerEnvs
   */
  private class MapFunction(paramMap: Map[String, Any],
                            round: Int,
                            workerEnvs: java.util.Map[String, String])
    extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
    val logger = LogFactory.getLog(this.getClass)

    def mapPartition(it: java.lang.Iterable[LabeledVector],
                     collector: Collector[XGBoostModel]): Unit = {
      workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
      logger.info("start with env" + workerEnvs.toString)
      Communicator.init(workerEnvs)
      val mapper = (x: LabeledVector) => {
        val (index, value) = x.vector.toSeq.unzip
        LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
      }
      val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
      val trainMat = new DMatrix(dataIter, null)
      val watches = List("train" -> trainMat).toMap
      val round = 2
      val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds")
        .map(_.toString.toInt).getOrElse(0)
      val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
        earlyStoppingRound = numEarlyStoppingRounds)
      Communicator.shutdown()
      collector.collect(new XGBoostModel(booster))
    }
  }

  val logger = LogFactory.getLog(this.getClass)

  /**
   * Load XGBoost model from path, using Hadoop Filesystem API.
   *
   * @param modelPath The path that is accessible by hadoop filesystem API.
   * @return The loaded model
   */
  def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
    new XGBoostModel(
      XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath))))
  }

  /**
   * Train a xgboost model with link.
   *
   * @param dtrain The training data.
   * @param params The parameters to XGBoost.
   * @param round Number of rounds to train.
   */
  def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
      XGBoostModel = {
    val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
    if (tracker.start(0L)) {
      dtrain
        .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
        .reduce((x, y) => x).collect().head
    } else {
      throw new Error("Tracker cannot be started")
      null
    }
  }
}

@@ -1,67 +0,0 @@
/*
  Copyright (c) 2014 by Contributors

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.flink

import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}

import org.apache.flink.api.scala.{DataSet, _}
import org.apache.flink.ml.math.Vector
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

class XGBoostModel (booster: Booster) extends Serializable {
  /**
   * Save the model as a Hadoop filesystem file.
   *
   * @param modelPath The model path as in Hadoop path.
   */
  def saveModelAsHadoopFile(modelPath: String): Unit = {
    booster.saveModel(FileSystem
      .get(new Configuration)
      .create(new Path(modelPath)))
  }

  /**
   * predict with the given DMatrix
   * @param testSet the local test set represented as DMatrix
   * @return prediction result
   */
  def predict(testSet: DMatrix): Array[Array[Float]] = {
    booster.predict(testSet, true, 0)
  }

  /**
   * Predict given vector dataset.
   *
   * @param data The dataset to be predicted.
   * @return The prediction result.
   */
  def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
    val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
      (it: Iterator[Vector]) => {
        val mapper = (x: Vector) => {
          val (index, value) = x.toSeq.unzip
          LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
        }
        val dataIter = for (x <- it) yield mapper(x)
        val dmat = new DMatrix(dataIter, null)
        this.booster.predict(dmat)
      }
    data.mapPartition(predictMap)
  }
}

@@ -38,22 +38,10 @@
<version>4.13.2</version>
<scope>test</scope>
</dependency>
-<dependency>
-<groupId>com.typesafe.akka</groupId>
-<artifactId>akka-actor_${scala.binary.version}</artifactId>
-<version>2.6.20</version>
-<scope>compile</scope>
-</dependency>
-<dependency>
-<groupId>com.typesafe.akka</groupId>
-<artifactId>akka-testkit_${scala.binary.version}</artifactId>
-<version>2.6.20</version>
-<scope>test</scope>
-</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
-<version>3.0.5</version>
+<version>3.2.15</version>
<scope>provided</scope>
</dependency>
<dependency>

@@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.Table
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch

-class QuantileDMatrixSuite extends FunSuite {
+class QuantileDMatrixSuite extends AnyFunSuite {

test("QuantileDMatrix test") {

@@ -44,13 +44,6 @@
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
-<dependency>
-<groupId>ai.rapids</groupId>
-<artifactId>cudf</artifactId>
-<version>${cudf.version}</version>
-<classifier>${cudf.classifier}</classifier>
-<scope>provided</scope>
-</dependency>
<dependency>
<groupId>com.nvidia</groupId>
<artifactId>rapids-4-spark_${scala.binary.version}</artifactId>

@@ -20,14 +20,15 @@ import java.nio.file.{Files, Path}
import java.sql.{Date, Timestamp}
import java.util.{Locale, TimeZone}

-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.{GpuTestUtils, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}

-trait GpuTestSuite extends FunSuite with TmpFolderSuite {
+trait GpuTestSuite extends AnyFunSuite with TmpFolderSuite {
import SparkSessionHolder.withSparkSession

protected def getResourcePath(resource: String): String = {
@@ -200,7 +201,7 @@ trait GpuTestSuite extends FunSuite with TmpFolderSuite {

}

-trait TmpFolderSuite extends BeforeAndAfterAll { self: FunSuite =>
+trait TmpFolderSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
protected var tempDir: Path = _

override def beforeAll(): Unit = {

@@ -1,5 +1,5 @@
/*
-Copyright (c) 2021-2022 by Contributors
+Copyright (c) 2021-2023 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -22,7 +22,6 @@ import java.util.ServiceLoader
import scala.collection.JavaConverters._
import scala.collection.{AbstractIterator, Iterator, mutable}

-import ml.dmlc.xgboost4j.java.Communicator
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
@@ -35,7 +34,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.logging.LogFactory

import org.apache.spark.TaskContext
-import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
@@ -263,12 +261,6 @@ object PreXGBoost extends PreXGBoostProvider {
private var batchCnt = 0

private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
-if (batchCnt == 0) {
-val rabitEnv = Array(
-"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
-Communicator.init(rabitEnv.asJava)
-}
-
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))

import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
@@ -295,13 +287,8 @@ object PreXGBoost extends PreXGBoostProvider {

override def hasNext: Boolean = batchIterImpl.hasNext

-override def next(): Row = {
-val ret = batchIterImpl.next()
-if (!batchIterImpl.hasNext) {
-Communicator.shutdown()
-}
-ret
-}
+override def next(): Row = batchIterImpl.next()

}
}
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -23,7 +23,6 @@ import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
@@ -44,21 +43,16 @@ import org.apache.spark.sql.SparkSession
|
||||
* Use a finite, non-zero timeout value to prevent tracker from
|
||||
* hanging indefinitely (in milliseconds)
|
||||
* (supported by "scala" implementation only.)
|
||||
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
|
||||
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
|
||||
* in Scala without Python components, and with full support of timeouts.
|
||||
* The Scala implementation is currently experimental, use at your own risk.
|
||||
*
|
||||
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
|
||||
* This is only needed if the host IP cannot be automatically guessed.
|
||||
* @param pythonExec The python executed path for Rabit Tracker,
|
||||
* which is only used for python implementation.
|
||||
*/
|
||||
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String,
|
||||
case class TrackerConf(workerConnectionTimeout: Long,
|
||||
hostIp: String = "", pythonExec: String = "")
|
||||
|
||||
object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(0L, "python")
|
||||
def apply(): TrackerConf = TrackerConf(0L)
|
||||
}
|
||||
|
||||
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
|
||||
@@ -349,11 +343,9 @@ object XGBoost extends Serializable {
|
||||
|
||||
/** visiable for testing */
|
||||
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||
val tracker: IRabitTracker = trackerConf.trackerImpl match {
|
||||
case "scala" => new RabitTracker(nWorkers)
|
||||
case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec)
|
||||
case _ => new PyRabitTracker(nWorkers)
|
||||
}
|
||||
val tracker: IRabitTracker = new PyRabitTracker(
|
||||
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
|
||||
)
|
||||
tracker
|
||||
}
|
||||
|
||||
|
||||
@@ -22,11 +22,10 @@ import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
class CommunicatorRobustnessSuite extends FunSuite with PerTest {
|
||||
class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
|
||||
val classifier = new XGBoostClassifier(paramMap)
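The same ScalaTest migration recurs through the rest of the test sources below; as a sketch, the per-file pattern (suite name hypothetical) is:

    // ScalaTest split the monolithic style traits into per-style packages;
    // org.scalatest.FunSuite is replaced by org.scalatest.funsuite.AnyFunSuite.
    import org.scalatest.funsuite.AnyFunSuite

    class MySuite extends AnyFunSuite {
      test("example") { assert(1 + 1 == 2) }
    }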
@@ -40,7 +39,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {

val paramMap = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp))
"tracker_conf" -> TrackerConf(0L, hostIp))
val xgbExecParams = getXGBoostExecutionParams(paramMap)
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
tracker match {
@@ -53,7 +52,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {

val paramMap1 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
tracker1 match {
@@ -66,7 +65,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {

val paramMap2 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
tracker2 match {
@@ -78,58 +77,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
}
}
test("training with Scala-implemented Rabit tracker") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}

test("test Communicator allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()

val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()

val rawData = rdd.mapPartitions { iter =>
Iterator(iter.toArray)
}.collect()

val maxVec = (0 until vectorLength).toArray.map { j =>
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
}

val allReduceResults = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
val arr = iter.toArray
val results = Communicator.allReduce(arr, Communicator.OpType.MAX)
Communicator.shutdown()
Iterator(results)
}.cache()

val sparkThread = new Thread() {
override def run(): Unit = {
allReduceResults.foreachPartition(() => _)
val byPartitionResults = allReduceResults.collect()
assert(byPartitionResults(0).length == vectorLength)
collectedAllReduceResults.put(byPartitionResults(0))
}
}
sparkThread.start()
assert(tracker.waitFor(0L) == 0)
sparkThread.join()

assert(collectedAllReduceResults.poll().sameElements(maxVec))
}

test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@@ -193,68 +140,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
assert(tracker.waitFor(0) != 0)
}

test("test Scala RabitTracker's exception handling: it should not hang forever.") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()

val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs

val workerCount: Int = numWorkers
val dummyTasks = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Communicator.shutdown()
Iterator(index)
}.cache()

val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}

test("test Scala RabitTracker's workerConnectionTimeout") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()

val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(500)
val trackerEnvs = tracker.getWorkerEnvs

val dummyTasks = rdd.mapPartitions { iter =>
val index = iter.next()
// simulate that the first worker cannot connect to the tracker due to network issues.
if (index != 1) {
Communicator.init(trackerEnvs)
Thread.sleep(1000)
Communicator.shutdown()
}

Iterator(index)
}.cache()

val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
// should fail due to connection timeout
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}

test("should allow the dataframe containing communicator calls to be partially evaluated for" +
" multiple times (ISSUE-4406)") {
val paramMap = Map(
@@ -17,13 +17,13 @@
package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.ml.linalg.Vectors
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams

import org.apache.spark.sql.functions._

class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {

test("perform deterministic partitioning when checkpointInternal and" +
" checkpointPath is set (Classifier)") {

@@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File

import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.hadoop.fs.{FileSystem, Path}

class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {

private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
Map[String, Any] = {

@@ -18,12 +18,12 @@ package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.sql.functions._

import scala.util.Random

class FeatureSizeValidatingSuite extends FunSuite with PerTest {
class FeatureSizeValidatingSuite extends AnyFunSuite with PerTest {

test("transform throwing exception if feature size of dataset is greater than model's") {
val modelPath = getClass.getResource("/model/0.82/model").getPath

@@ -19,12 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import scala.util.Random

import org.apache.spark.SparkException

class MissingValueHandlingSuite extends FunSuite with PerTest {
class MissingValueHandlingSuite extends AnyFunSuite with PerTest {
test("dense vectors containing missing value") {
def buildDenseDataFrame(): DataFrame = {
val numRows = 100
@@ -16,12 +16,13 @@

package ml.dmlc.xgboost4j.scala.spark

import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.SparkException
import org.apache.spark.ml.param.ParamMap

class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {

test("XGBoost and Spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",

@@ -22,13 +22,14 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import scala.math.min
import scala.util.Random

import org.apache.commons.io.IOUtils

trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>

protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
@@ -25,9 +25,9 @@ import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.functions._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite

class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class PersistenceSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {

test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
val eval = new EvalError()

@@ -19,9 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.{Files, Path}

import org.apache.spark.network.util.JavaUtils
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite

trait TmpFolderPerSuite extends BeforeAndAfterAll { self: FunSuite =>
trait TmpFolderPerSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
protected var tempDir: Path = _

override def beforeAll(): Unit = {

@@ -22,13 +22,13 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}

import org.apache.spark.ml.linalg._
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.commons.io.IOUtils

import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler

class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuite {
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {

protected val treeMethod: String = "auto"

@@ -21,11 +21,11 @@ import ml.dmlc.xgboost4j.scala.Booster
import scala.collection.JavaConverters._

import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.SparkException

class XGBoostCommunicatorRegressionSuite extends FunSuite with PerTest {
class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
val predictionErrorMin = 0.00001f
val maxFailure = 2;
@@ -19,9 +19,9 @@ package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}

import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite

class XGBoostConfigureSuite extends FunSuite with PerTest {
class XGBoostConfigureSuite extends AnyFunSuite with PerTest {

override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

@@ -22,12 +22,12 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix

import org.apache.spark.{SparkException, TaskContext}
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.lit

class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {

test("distributed training with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)

@@ -23,11 +23,11 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row}
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.ml.feature.VectorAssembler

class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite {
class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"

test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
@@ -69,7 +69,7 @@ pom_template = """
<dependency>
<groupId>org.scalactic</groupId>
<artifactId>scalactic_${{scala.binary.version}}</artifactId>
<version>3.0.8</version>
<version>3.2.15</version>
<scope>test</scope>
</dependency>
<dependency>

@@ -31,22 +31,10 @@
<version>4.13.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor_${scala.binary.version}</artifactId>
<version>2.6.20</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-testkit_${scala.binary.version}</artifactId>
<version>2.6.20</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>3.0.5</version>
<version>3.2.15</version>
<scope>provided</scope>
</dependency>
</dependencies>
@@ -1,195 +0,0 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.rabit

import java.net.{InetAddress, InetSocketAddress}

import akka.actor.ActorSystem
import akka.pattern.ask
import ml.dmlc.xgboost4j.java.{IRabitTracker, TrackerProperties}
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler

import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.util.{Failure, Success, Try}

/**
* Scala implementation of the Rabit tracker interface without Python dependency.
* The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker
* (and thus any distributed tasks) from hanging indefinitely due to network issues or worker node
* failures.
*
* Note that this implementation is currently experimental, and should be used at your own risk.
*
* Example usage:
* {{{
* import scala.concurrent.duration._
*
* val tracker = new RabitTracker(32)
* // allow up to 10 minutes for all workers to connect to the tracker.
* tracker.start(10 minutes)
*
* /* ...
* launching workers in parallel
* ...
* */
*
* // wait for worker execution up to 6 hours.
* // providing a finite timeout prevents a long-running task from hanging forever in
* // catastrophic events, like the loss of an executor during model training.
* tracker.waitFor(6 hours)
* }}}
*
* @param numWorkers Number of distributed workers from which the tracker expects connections.
* @param port The minimum port number that the tracker binds to.
* If port is omitted, or given as None, a random ephemeral port is chosen at runtime.
* @param maxPortTrials The maximum number of trials of socket binding, by sequentially
* increasing the port number.
*/
private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
maxPortTrials: Int = 1000)
extends IRabitTracker {

import scala.collection.JavaConverters._

require(numWorkers >= 1, "numWorkers must be greater than or equal to one (1).")

val system = ActorSystem.create("RabitTracker")
val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds)
private[this] val tcpBindingTimeout: Duration = 1 minute

var workerEnvs: Map[String, String] = Map.empty

override def uncaughtException(t: Thread, e: Throwable): Unit = {
handler ? RabitTrackerHandler.InterruptTracker(e)
}

/**
* Start the Rabit tracker.
*
* @param timeout The timeout for awaiting connections from worker nodes.
* Note that when used in Spark applications, because all Spark transformations are
* lazily executed, the I/O time for loading RDDs/DataFrames from external sources
* (local disk, HDFS, S3 etc.) must be taken into account for the timeout value.
* If the timeout value is too small, the Rabit tracker will likely time out before workers
* establish connections to the tracker, due to the overhead of loading data.
* Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
* running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
* @return Boolean flag indicating if the Rabit tracker starts successfully.
*/
private def start(timeout: Duration): Boolean = {
val hostAddress = Option(TrackerProperties.getInstance().getHostIp)
.map(InetAddress.getByName).getOrElse(InetAddress.getLocalHost)

handler ? RabitTrackerHandler.StartTracker(
new InetSocketAddress(hostAddress, port.getOrElse(0)), maxPortTrials, timeout)

// block by waiting for the actor to bind to a port
Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
.asInstanceOf[Future[Map[String, String]]]) match {
case Success(futurePortBound) =>
// The success of the Future is contingent on binding to an InetSocketAddress.
val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess
if (isBound) {
workerEnvs = Await.result(futurePortBound, 0 nano)
}
isBound
case Failure(ex: Throwable) =>
false
}
}

/**
* Start the Rabit tracker.
*
* @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
* connections. If a non-positive value is provided, the tracker
* waits for incoming worker connections indefinitely.
* @return Boolean flag indicating if the Rabit tracker starts successfully.
*/
def start(connectionTimeoutMillis: Long): Boolean = {
if (connectionTimeoutMillis <= 0) {
start(Duration.Inf)
} else {
start(Duration.fromNanos(connectionTimeoutMillis * 1e6))
}
}

def stop(): Unit = {
system.terminate()
}

/**
* Get a Map of necessary environment variables to initiate Rabit workers.
*
* @return HashMap containing tracker information.
*/
def getWorkerEnvs: java.util.Map[String, String] = {
new java.util.HashMap((workerEnvs ++ Map(
"DMLC_NUM_WORKER" -> numWorkers.toString,
"DMLC_NUM_SERVER" -> "0"
)).asJava)
}

/**
* Await workers to complete assigned tasks for at most the given duration.
* This method blocks until timeout or task completion.
*
* @param atMost the maximum execution time for the workers. By default,
* the tracker waits for the workers indefinitely.
* @return 0 if the tasks complete successfully, and non-zero otherwise.
*/
private def waitFor(atMost: Duration): Int = {
// request the completion Future from the tracker actor
Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
.asInstanceOf[Future[Int]]) match {
case Success(futureCompleted) =>
// wait for all workers to complete synchronously.
val statusCode = Try(Await.result(futureCompleted, atMost)) match {
case Success(n) if n == numWorkers =>
IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
case Success(n) if n < numWorkers =>
IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
case Failure(e) =>
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
}
system.terminate()
statusCode
case Failure(ex: Throwable) =>
system.terminate()
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
}
}

/**
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
* This method blocks until timeout or task completion.
*
* @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a
* non-positive number is given, the tracker waits indefinitely.
* @return 0 if the tasks complete successfully, and non-zero otherwise
*/
def waitFor(atMostMillis: Long): Int = {
if (atMostMillis <= 0) {
waitFor(Duration.Inf)
} else {
waitFor(Duration.fromNanos(atMostMillis * 1e6))
}
}
}
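As a recap of the deleted class above, a minimal usage sketch of its public millisecond-based API (worker launch elided, values illustrative):

    val tracker = new RabitTracker(numWorkers = 4)
    if (tracker.start(10 * 60 * 1000L)) {               // allow 10 minutes for connections
      val envs = tracker.getWorkerEnvs                  // DMLC_TRACKER_URI, DMLC_TRACKER_PORT, ...
      // ... launch workers with envs ...
      val status = tracker.waitFor(6 * 60 * 60 * 1000L) // wait up to 6 hours; 0 means success
    }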
@@ -1,361 +0,0 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.rabit.handler

import java.net.InetSocketAddress
import java.util.UUID

import scala.concurrent.duration._
import scala.collection.mutable
import scala.concurrent.{Promise, TimeoutException}
import akka.io.{IO, Tcp}
import akka.actor._
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}

import scala.util.{Failure, Random, Success, Try}

/** The Akka actor for handling and coordinating Rabit worker connections.
* This is the main actor for handling socket connections, interacting with the synchronous
* tracker interface, and resolving tree/ring/parent dependencies between workers.
*
* @param numWorkers Number of workers to track.
*/
private[scala] class RabitTrackerHandler(numWorkers: Int)
extends Actor with ActorLogging {

import context.system
import RabitWorkerHandler._
import RabitTrackerHandler._

private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
private[this] val promisedShutdownWorkers = Promise[Int]()
private[this] val tcpManager = IO(Tcp)

// resolves worker connection dependency.
val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")

// workers that have sent "shutdown" signal
private[this] val shutdownWorkers = mutable.Set.empty[Int]
private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
private[this] var maxPortTrials = 0
private[this] var workerConnectionTimeout: Duration = Duration.Inf
private[this] var portTrials = 0
private[this] val startedWorkers = mutable.Set.empty[Int]

val linkMap = new LinkMap(numWorkers)

def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
rank match {
case r if r >= 0 => Some(r)
case _ =>
jobId match {
case "NULL" => None
case jid => jobToRankMap.get(jid)
}
}
}
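// Illustration (added for this writeup, not in the original file): the
// resolution rules above in effect --
//   decideRank(3)            == Some(3)                    // explicit rank wins
//   decideRank(-1, "job-42") == jobToRankMap.get("job-42") // recovery by job id
//   decideRank(-1)           == None                       // fresh worker: caller assigns the next free rank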
/**
* Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
* by the RabitWorkerHandler.
*
* @param event Generic Tcp.Event
*/
private def handleTcpEvents(event: Tcp.Event): Unit = event match {
case Tcp.Bound(local) =>
// expect all workers to connect within timeout
log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
log.info(s"Worker connection timeout is $workerConnectionTimeout.")

context.setReceiveTimeout(workerConnectionTimeout)
promisedWorkerEnvs.success(Map(
"DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
"DMLC_TRACKER_PORT" -> local.getPort.toString,
// not required because the world size will be communicated to the
// worker node after the rank is assigned.
"rabit_world_size" -> numWorkers.toString
))

case Tcp.CommandFailed(cmd: Tcp.Bind) =>
if (portTrials < maxPortTrials) {
portTrials += 1
tcpManager ! Tcp.Bind(self,
new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1),
backlog = 256)
}

case Tcp.Connected(remote, local) =>
log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
// revoke timeout if all workers have connected.
val workerHandler = context.actorOf(RabitWorkerHandler.props(
remote.getAddress.getHostAddress, numWorkers, self, sender()
), s"ConnectionHandler-${UUID.randomUUID().toString}")
val connection = sender()
connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)

actorRefToHost.put(workerHandler, remote.getAddress.getHostName)
}

/**
* Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
* to interact with the tracker interface.
*
* @param trackerMsg control messages sent by RabitTracker class.
*/
private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
trackerMsg match {

case msg: StartTracker =>
maxPortTrials = msg.maxPortTrials
workerConnectionTimeout = msg.connectionTimeout

// if the port number is missing, try binding to a random ephemeral port.
if (msg.addr.getPort == 0) {
tcpManager ! Tcp.Bind(self,
new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
backlog = 256)
} else {
tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
}
sender() ! true

case RequestBoundFuture =>
sender() ! promisedWorkerEnvs.future

case RequestCompletionFuture =>
sender() ! promisedShutdownWorkers.future

case InterruptTracker(e) =>
log.error(e, "Uncaught exception thrown by worker.")
// make sure that waitFor() does not hang indefinitely.
promisedShutdownWorkers.failure(e)
context.stop(self)
}

/**
* Handles messages sent by child actors representing connecting Rabit workers, by brokering
* messages to the dependency resolver, and processing worker commands.
*
* @param workerMsg Message sent by RabitWorkerHandler actors.
*/
private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
case req @ RequestAwaitConnWorkers(_, _) =>
// since the requester may request to connect to other workers
// that have not fully set up, delegate this request to the
// dependency resolver which handles the dependencies properly.
resolver forward req

// ---- Rabit worker commands: start/recover/shutdown/print ----
case WorkerTrackerPrint(_, _, _, msg) =>
log.info(msg.trim)

case WorkerShutdown(rank, _, _) =>
assert(rank >= 0, "Invalid rank.")
assert(!shutdownWorkers.contains(rank))
shutdownWorkers.add(rank)

log.info(s"Received shutdown signal from $rank")

if (shutdownWorkers.size == numWorkers) {
promisedShutdownWorkers.success(shutdownWorkers.size)
}

case WorkerRecover(prevRank, worldSize, jobId) =>
assert(prevRank >= 0)
sender() ! linkMap.assignRank(prevRank)

case WorkerStart(rank, worldSize, jobId) =>
assert(worldSize == numWorkers || worldSize == -1,
s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
)

Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
case Success(wkRank) =>
if (jobId != "NULL") {
jobToRankMap.put(jobId, wkRank)
}

val assignedRank = linkMap.assignRank(wkRank)
sender() ! assignedRank
resolver ! assignedRank

log.info("Received start signal from " +
s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]")

case Failure(ex: IndexOutOfBoundsException) =>
// More than worldSize workers have connected, likely due to executor loss.
// Since Rabit currently does not support crash recovery (because the Allreduce results
// are not cached by the tracker, and because existing workers cannot reestablish
// connections to newly spawned executor/worker), the most reasonable action here is to
// interrupt the tracker immediately with a failure state.
log.error("Received invalid start signal from " +
s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started."
)
promisedShutdownWorkers.failure(new XGBoostError("Invalid start signal" +
" received from worker, likely due to executor loss."))

case Failure(ex) =>
log.error(ex, "Unexpected error")
promisedShutdownWorkers.failure(ex)
}

// ---- Dependency resolving related messages ----
case msg @ WorkerStarted(host, rank, awaitingAcceptance) =>
log.info(s"Worker $host (rank: $rank) has started.")
resolver forward msg

startedWorkers.add(rank)
if (startedWorkers.size == numWorkers) {
log.info("All workers have started.")
}

case req @ DropFromWaitingList(_) =>
// all peer workers in dependency link map have connected;
// forward message to resolver to update dependencies.
resolver forward req

case _ =>
}

def receive: Actor.Receive = {
case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)

case akka.actor.ReceiveTimeout =>
if (startedWorkers.size < numWorkers) {
promisedShutdownWorkers.failure(
new TimeoutException("Timed out waiting for workers to connect: " +
s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
)
context.stop(self)
}

context.setReceiveTimeout(Duration.Undefined)
}
}

/**
* Resolve the dependency between nodes as they connect to the tracker.
* The enforced dependency is that a worker of rank K depends on its neighbors (from the treeMap
* and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order in which
* workers connect, this dependency constraint assumes that a worker node that connects first
* is likely to finish setup first.
*/
private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
import RabitWorkerHandler._

context.watch(handler)

case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])

// worker nodes that have connected, but have not sent the WorkerStarted message.
private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
private val startedWorkers = mutable.Set.empty[Int]
// worker nodes that have started, and await connections.
private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]

def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
val connSet = awaitConnWorkers.toMap
.filterKeys(k => linkSet.contains(k))
AwaitingConnections(connSet, linkSet.size - connSet.size)
}

def receive: Actor.Receive = {
// a copy of the AssignedRank message that is also sent to the worker
case AssignedRank(rank, tree_neighbors, ring, parent) =>
// the workers that the worker of given `rank` depends on:
// worker of rank K only depends on workers with rank smaller than K.
val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
.filter{ r => r != -1 && r < rank}

log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
dependencyMap.put(rank, dependentWorkers)

case RequestAwaitConnWorkers(rank, toConnectSet) =>
val promise = Promise[AwaitingConnections]()

assert(dependencyMap.contains(rank))

val updatedDependency = dependencyMap(rank) diff startedWorkers
if (updatedDependency.isEmpty) {
// all dependencies are satisfied
log.debug(s"Rank $rank has all dependencies satisfied.")
promise.success(awaitingWorkers(toConnectSet))
} else {
log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
// promise is pending fulfillment due to unresolved dependency
pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
}

sender() ! promise.future

case WorkerStarted(_, started, awaitingAcceptance) =>
startedWorkers.add(started)
if (awaitingAcceptance > 0) {
awaitConnWorkers.put(started, sender())
}

// remove the started rank from all dependencies.
dependencyMap.remove(started)
dependencyMap.foreach { case (r, dset) =>
val updatedDependency = dset diff startedWorkers
// fulfill the future if all dependencies are met (started.)
if (updatedDependency.isEmpty) {
log.debug(s"Rank $r has all dependencies satisfied.")
pendingFulfillment.remove(r).map{
case Fulfillment(toConnectSet, promise) =>
promise.success(awaitingWorkers(toConnectSet))
}
}

dependencyMap.update(r, updatedDependency)
}

case DropFromWaitingList(rank) =>
assert(awaitConnWorkers.remove(rank).isDefined)

case Terminated(ref) =>
if (ref.equals(handler)) {
context.stop(self)
}
}
}

private[scala] object RabitTrackerHandler {
// Messages sent by RabitTracker to this RabitTrackerHandler actor
trait TrackerControlMessage
case object RequestCompletionFuture extends TrackerControlMessage
case object RequestBoundFuture extends TrackerControlMessage
// Start the Rabit tracker at given socket address awaiting worker connections.
// All workers must connect to the tracker before connectionTimeout, otherwise the tracker will
// shut down due to timeout.
case class StartTracker(addr: InetSocketAddress,
maxPortTrials: Int,
connectionTimeout: Duration) extends TrackerControlMessage
// To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
// driver for the distributed training.
case class InterruptTracker(e: Throwable) extends TrackerControlMessage

def props(numWorkers: Int): Props =
Props(new RabitTrackerHandler(numWorkers))
}
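A hedged sketch of the ask-pattern control flow between RabitTracker and this handler, using only the messages defined above (timeout and address values illustrative):

    import akka.pattern.ask
    import scala.concurrent.duration._

    implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30.seconds)
    // port 0 makes the handler pick a random ephemeral port, per StartTracker handling above
    handler ? RabitTrackerHandler.StartTracker(
      new java.net.InetSocketAddress("0.0.0.0", 0), maxPortTrials = 1000, 1.minute)
    // the bound future resolves to the worker env map once Tcp.Bound arrives
    val bound = (handler ? RabitTrackerHandler.RequestBoundFuture)
      .mapTo[scala.concurrent.Future[Map[String, String]]]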
@@ -1,467 +0,0 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.rabit.handler

import java.nio.{ByteBuffer, ByteOrder}

import akka.io.Tcp
import akka.actor._
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}

import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Try

/**
* Actor to handle socket communication from a worker node.
* To handle fragmentation in received data, this class acts like an FSM
* (finite-state machine) to keep track of the internal states.
*
* @param host IP address of the remote worker
* @param worldSize number of total workers
* @param tracker the RabitTrackerHandler actor reference
*/
private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
connection: ActorRef)
extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
with ActorLogging with Stash {

import RabitWorkerHandler._
import RabitTrackerHelpers._

private[this] var rank: Int = 0
private[this] var port: Int = 0

// indicate if the connection is transient (like "print" or "shutdown")
private[this] var transient: Boolean = false
private[this] var peerClosed: Boolean = false

// number of workers pending acceptance of the current worker
private[this] var awaitingAcceptance: Int = 0
private[this] var neighboringWorkers = Set.empty[Int]

// TODO: use a single memory allocation to host all buffers,
// including the transient ones for writing.
private[this] val readBuffer = ByteBuffer.allocate(4096)
.order(ByteOrder.nativeOrder())
// in case the received message is longer than needed,
// stash the spilled over part in this buffer, and send
// to self when transition occurs.
private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
.order(ByteOrder.nativeOrder())
// when setup is complete, need to notify peer handlers
// to reduce the awaiting-connection counter.
private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None

private def resetBuffers(): Unit = {
readBuffer.clear()
if (spillOverBuffer.position() > 0) {
spillOverBuffer.flip()
self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
spillOverBuffer.clear()
}
}

private def stashSpillOver(buf: ByteBuffer): Unit = {
if (buf.remaining() > 0) spillOverBuffer.put(buf)
}

def getNeighboringWorkers: Set[Int] = neighboringWorkers

def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
val readBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
readBuffer.flip()

val rank = readBuffer.getInt()
val worldSize = readBuffer.getInt()
val jobId = readBuffer.getString

val command = readBuffer.getString
val trackerCommand = command match {
case "start" => WorkerStart(rank, worldSize, jobId)
case "shutdown" =>
transient = true
WorkerShutdown(rank, worldSize, jobId)
case "recover" =>
require(rank >= 0, "Invalid rank for recovering worker.")
WorkerRecover(rank, worldSize, jobId)
case "print" =>
transient = true
WorkerTrackerPrint(rank, worldSize, jobId, readBuffer.getString)
}

stashSpillOver(readBuffer)
trackerCommand
}

startWith(AwaitingHandshake, DataStruct())

when(AwaitingHandshake) {
case Event(Tcp.Received(magic), _) =>
assert(magic.length == 4)
val purportedMagic = magic.asNativeOrderByteBuffer.getInt
assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")

// echo back the magic number
connection ! Tcp.Write(magic)
goto(AwaitingCommand) using StructTrackerCommand
}

when(AwaitingCommand) {
case Event(Tcp.Received(bytes), validator) =>
bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
if (validator.verify(readBuffer)) {
Try(decodeCommand(readBuffer)) match {
case scala.util.Success(decodedCommand) =>
tracker ! decodedCommand
case scala.util.Failure(th: java.nio.BufferUnderflowException) =>
// BufferUnderflowException would occur if the message to print has not arrived yet.
// Do nothing, wait for next Tcp.Received event
case scala.util.Failure(th: Throwable) => throw th
}
}

stay
// when rank for a worker is assigned, send encoded rank information
// back to worker over Tcp socket.
case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")

rank = assignedRank
// ranks from the ring
val ringRanks = List(
// ringPrev
if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
// ringNext
if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
)

// update the set of all linked workers to current worker.
neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet

connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
// to prevent reading before state transition
connection ! Tcp.SuspendReading
goto(BuildingLinkMap) using StructNodes
}

when(BuildingLinkMap) {
case Event(Tcp.Received(bytes), validator) =>
bytes.asByteBuffers.foreach { buf =>
readBuffer.put(buf)
}

if (validator.verify(readBuffer)) {
readBuffer.flip()
// for a freshly started worker, numConnected should be 0.
val numConnected = readBuffer.getInt()
val toConnectSet = neighboringWorkers.diff(
(0 until numConnected).map { index => readBuffer.getInt() }.toSet)

// check which workers are currently awaiting connections
tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
}
stay

// got a Future from the tracker (resolver) about workers that are
// currently awaiting connections (particularly from this node.)
case Event(future: Future[_], _) =>
// blocks execution until all dependencies for the current worker are resolved.
Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
// numNotReachable is the number of workers that currently
// cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
// connections from those currently non-reachable nodes in the future.
case AwaitingConnections(waitConnNodes, numNotReachable) =>
log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")
val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
buf.putInt(waitConnNodes.size).putInt(numNotReachable)
buf.flip()

// cache this message until the final state (SetupComplete)
pendingAcknowledgement = Some(AcknowledgeAcceptance(
waitConnNodes, numNotReachable))

connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
if (waitConnNodes.isEmpty) {
connection ! Tcp.SuspendReading
goto(AwaitingErrorCount)
}
else {
waitConnNodes.foreach { case (peerRank, peerRef) =>
peerRef ! RequestWorkerHostPort
}

// a countdown for DivulgedHostPort messages.
stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
}
}

case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
val hostBytes = peerHost.getBytes()
val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
.order(ByteOrder.nativeOrder())
buffer.putInt(peerHost.length).put(hostBytes)
.putInt(peerPort).putInt(peerRank)

buffer.flip()
connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))

if (data.counter == 0) {
// to prevent reading before state transition
connection ! Tcp.SuspendReading
goto(AwaitingErrorCount)
}
else {
stay using data.decrement()
}
}

when(AwaitingErrorCount) {
case Event(Tcp.Received(numErrors), _) =>
val buf = numErrors.asNativeOrderByteBuffer

buf.getInt match {
case 0 =>
stashSpillOver(buf)
goto(AwaitingPortNumber)
case _ =>
stashSpillOver(buf)
goto(BuildingLinkMap) using StructNodes
}
}

when(AwaitingPortNumber) {
case Event(Tcp.Received(assignedPort), _) =>
assert(assignedPort.length == 4)
port = assignedPort.asNativeOrderByteBuffer.getInt
log.debug(s"Rank $rank listening @ $host:$port")
// wait until the worker closes connection.
if (peerClosed) goto(SetupComplete) else stay

case Event(Tcp.PeerClosed, _) =>
peerClosed = true
if (port == 0) stay else goto(SetupComplete)
}

when(SetupComplete) {
case Event(ReduceWaitCount(count: Int), _) =>
awaitingAcceptance -= count
// check peerClosed to avoid prematurely stopping this actor (which sends RST to worker)
if (awaitingAcceptance == 0 && peerClosed) {
tracker ! DropFromWaitingList(rank)
// no longer needed.
context.stop(self)
}
stay

case Event(AcknowledgeAcceptance(peers, numBad), _) =>
awaitingAcceptance = numBad
tracker ! WorkerStarted(host, rank, awaitingAcceptance)
peers.values.foreach { peer =>
peer ! ReduceWaitCount(1)
}

if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill

stay

// can only divulge the complete host and port information
// when this worker is declared fully connected (otherwise
// port information is still missing.)
case Event(RequestWorkerHostPort, _) =>
sender() ! DivulgedWorkerHostPort(rank, host, port)
stay
}

onTransition {
// reset buffer when state transitions as data becomes stale
case _ -> SetupComplete =>
connection ! Tcp.ResumeReading
resetBuffers()
if (pendingAcknowledgement.isDefined) {
self ! pendingAcknowledgement.get
}
case _ =>
connection ! Tcp.ResumeReading
resetBuffers()
}

// default message handler
whenUnhandled {
case Event(Tcp.PeerClosed, _) =>
peerClosed = true
if (transient) context.stop(self)
stay
}
}

private[scala] object RabitWorkerHandler {
val MAGIC_NUMBER = 0xff99

// Finite states of this actor, which acts like an FSM.
// The following states are defined in order as the FSM progresses.
sealed trait State

// [1] Initial state, awaiting worker to send magic number per protocol.
case object AwaitingHandshake extends State
// [2] Awaiting worker to send command (start/print/recover/shutdown etc.)
case object AwaitingCommand extends State
// [3] Brokers connections between workers per ring/tree/parent link map.
case object BuildingLinkMap extends State
// [4] A transient state in which the worker reports the number of errors in establishing
// connections to other peer workers. If no errors, transition to next state.
case object AwaitingErrorCount extends State
// [5] Awaiting the worker to report its port number for accepting connections from peer workers.
// This port number information is later forwarded to linked workers.
case object AwaitingPortNumber extends State
// [6] Final state after completing the setup with the connecting worker. At this stage, the
// worker will have closed the Tcp connection. The actor remains alive to handle messages from
// peer actors representing workers with pending setups.
case object SetupComplete extends State
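// Illustration (added for this writeup, not in the original file): in state [1]
// the worker sends the 4-byte magic number 0xff99 in native byte order, and the
// handler above echoes the same four bytes back before moving to state [2].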

sealed trait DataField
case object IntField extends DataField
// an integer preceding the actual string
case object StringField extends DataField
case object IntSeqField extends DataField

object DataStruct {
def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
}

// Internal data pertaining to individual state, used to verify the validity of packets sent by
// workers.
case class DataStruct(fields: Seq[DataField], counter: Int) {
/**
* Validate whether the provided buffer is complete (i.e., contains
* all data fields specified for this DataStruct.)
*
* @param buf a byte buffer containing received data.
*/
def verify(buf: ByteBuffer): Boolean = {
if (fields.isEmpty) return true

val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
dupBuf.flip()

Try(fields.foldLeft(true) {
case (complete, field) =>
val remBytes = dupBuf.remaining()
complete && (remBytes > 0) && (remBytes >= (field match {
case IntField =>
dupBuf.position(dupBuf.position() + 4)
4
case StringField =>
val strLen = dupBuf.getInt
dupBuf.position(dupBuf.position() + strLen)
4 + strLen
case IntSeqField =>
val seqLen = dupBuf.getInt
dupBuf.position(dupBuf.position() + seqLen * 4)
4 + seqLen * 4
}))
}).getOrElse(false)
}

def increment(): DataStruct = DataStruct(fields, counter + 1)
def decrement(): DataStruct = DataStruct(fields, counter - 1)
}

val StructNodes = DataStruct(List(IntSeqField), 0)
val StructTrackerCommand = DataStruct(List(
IntField, IntField, StringField, StringField
), 0)
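// Illustration (added for this writeup, not in the original file): the wire
// layout that StructTrackerCommand validates is, in native byte order,
//   int rank | int worldSize | int len + jobId bytes | int len + command bytes
// so, e.g., WorkerStart(-1, 4, "NULL").encode below yields
// 4 + 4 + (4 + 4) + (4 + 5) = 25 bytes for jobId "NULL" and command "start".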
|
||||
|
||||
// ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ----
|
||||
|
||||
// RabitWorkerHandler --> RabitTrackerHandler
|
||||
sealed trait RabitWorkerRequest
|
||||
// RabitWorkerHandler <-- RabitTrackerHandler
|
||||
sealed trait RabitWorkerResponse
|
||||
|
||||
// Representations of decoded worker commands.
|
||||
abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
|
||||
def rank: Int
|
||||
def worldSize: Int
|
||||
def jobId: String
|
||||
|
||||
def encode: ByteString = {
|
||||
val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
|
||||
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
|
||||
.putInt(command.length).put(command.getBytes()).flip()
|
||||
|
||||
ByteString.fromByteBuffer(buf)
|
||||
}
|
||||
}
|
||||
|
||||
case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
|
||||
extends TrackerCommand("start")
|
||||
case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
|
||||
extends TrackerCommand("shutdown")
|
||||
case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
|
||||
extends TrackerCommand("recover")
|
||||
case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
|
||||
extends TrackerCommand("print") {
|
||||
|
||||
override def encode: ByteString = {
|
||||
val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
|
||||
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
|
||||
.putInt(command.length).put(command.getBytes())
|
||||
.putInt(msg.length).put(msg.getBytes()).flip()
|
||||
|
||||
ByteString.fromByteBuffer(buf)
|
||||
}
|
||||
}

// Request to remove the worker of given rank from the list of workers awaiting peer connections.
case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
// Notify the tracker that the worker of given rank has finished setup and started.
case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
  extends RabitWorkerRequest
// Request the set of workers to connect to, according to the LinkMap structure.
case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
  extends RabitWorkerRequest

// Response from the tracker: the set of workers to connect to.
case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
  extends RabitWorkerResponse

// ---- Messages between ConnectionHandler actors ----
sealed trait IntraWorkerMessage

// Notify neighboring workers to decrease the counter of awaiting workers by `count`.
case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
// Request host and port information from peer ConnectionHandler actors (acting on behalf of
// connecting workers). This message will be brokered by RabitTrackerHandler.
case object RequestWorkerHostPort extends IntraWorkerMessage
// Response to the above request.
case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
// A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
  extends IntraWorkerMessage

// ---- End of message definitions ----

def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
  Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
}
}
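A brief, hypothetical sketch of how `props` would be used to spawn one handler per accepted connection (the `tracker` and `connection` ActorRefs are placeholders, not values defined in this file):

// Hypothetical wiring inside the tracker: one handler per incoming worker.
val system = ActorSystem("RabitTracker")
val tracker: ActorRef = ???      // placeholder: the RabitTrackerHandler actor
val connection: ActorRef = ???   // placeholder: the Akka IO connection actor
val handler = system.actorOf(
  RabitWorkerHandler.props("localhost", 4, tracker, connection))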
@@ -1,136 +0,0 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit.util

import java.nio.{ByteBuffer, ByteOrder}

/**
 * The rank assigned to a connecting Rabit worker, along with the ranks of its linked peer
 * workers, which are critical to performing Allreduce.
 * When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
 * client, RabitTrackerHandler uses LinkMap to figure out linkage relationships, and responds
 * with this class as a message, which is later encoded to a byte string and sent over the
 * socket connection to the worker client.
 *
 * @param rank assigned rank (ranked by worker connection order: the first worker connecting to
 *             the tracker is assigned rank 0, the second rank 1, etc.)
 * @param neighbors ranks of neighboring workers in a tree map.
 * @param ring ranks of neighboring workers in a ring map.
 * @param parent rank of the parent worker.
 */
private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
                                       ring: (Int, Int), parent: Int) {
  /**
   * Encode the AssignedRank message into a byte sequence for socket communication with the
   * Rabit worker client.
   * @param worldSize the total number of distributed workers. Must match `numWorkers` used in
   *                  LinkMap.
   * @return a ByteBuffer containing the encoded data.
   */
  def toByteBuffer(worldSize: Int): ByteBuffer = {
    val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder())
    buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length)
    // neighbors in tree structure
    neighbors.foreach { n => buffer.putInt(n) }
    buffer.putInt(if (ring._1 != -1 && ring._1 != rank) ring._1 else -1)
    buffer.putInt(if (ring._2 != -1 && ring._2 != rank) ring._2 else -1)

    buffer.flip()
    buffer
  }
}
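The encoded layout is: rank, parent, worldSize, neighbor count, the neighbor ranks, then the previous and next ring ranks (-1 when absent). A hedged decode sketch mirroring `toByteBuffer` (ours, for illustration only):

// Hypothetical decode of an AssignedRank frame (native byte order).
def decodeAssignedRank(buf: ByteBuffer): (Int, Int, Int, Seq[Int], (Int, Int)) = {
  val rank = buf.getInt
  val parent = buf.getInt
  val worldSize = buf.getInt
  val neighbors = Seq.fill(buf.getInt)(buf.getInt)  // count-prefixed neighbor list
  val ring = (buf.getInt, buf.getInt)               // prev and next ring ranks, -1 if none
  (rank, parent, worldSize, neighbors, ring)
}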

private[rabit] class LinkMap(numWorkers: Int) {
  private def getNeighbors(rank: Int): Seq[Int] = {
    val rank1 = rank + 1
    Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
      r >= 0 && r < numWorkers
    }
  }

  /**
   * Construct a ring structure that tends to share nodes with the tree.
   *
   * @param treeMap map from rank to the ranks of its tree neighbors.
   * @param parentMap map from rank to the rank of its tree parent (-1 for the root).
   * @param rank the rank to start the ring from.
   * @return Seq[Int] instance starting from rank.
   */
  private def constructShareRing(treeMap: Map[Int, Seq[Int]],
                                 parentMap: Map[Int, Int],
                                 rank: Int = 0): Seq[Int] = {
    treeMap(rank).toSet - parentMap(rank) match {
      case emptySet if emptySet.isEmpty =>
        List(rank)
      case connectionSet =>
        connectionSet.zipWithIndex.foldLeft(List(rank)) {
          case (ringSeq, (v, cnt)) =>
            val vConnSeq = constructShareRing(treeMap, parentMap, v)
            vConnSeq match {
              case vconn if vconn.size == cnt + 1 =>
                ringSeq ++ vconn.reverse
              case vconn =>
                ringSeq ++ vconn
            }
        }
    }
  }
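`getNeighbors` treats worker ranks as a 1-indexed binary heap: worker `rank + 1` has parent `(rank + 1) / 2` and children `2 * (rank + 1)` and `2 * (rank + 1) + 1`, each shifted back to a 0-indexed rank and clipped to the valid range. A small worked example:

// Worked example for getNeighbors with numWorkers = 6 (0-indexed ranks):
//   rank 0 -> Vector(-1, 1, 2).filter(...) == Seq(1, 2)     root: two children
//   rank 1 -> Vector(0, 3, 4).filter(...)  == Seq(0, 3, 4)  parent 0, children 3 and 4
//   rank 2 -> Vector(0, 5, 6).filter(...)  == Seq(0, 5)     child 6 is out of range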
  /**
   * Construct a ring connection used to recover local data.
   *
   * @param treeMap map from rank to the ranks of its tree neighbors.
   * @param parentMap map from rank to the rank of its tree parent (-1 for the root).
   */
  private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = {
    assert(parentMap(0) == -1)

    val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
    assert(sharedRing.length == treeMap.size)

    (0 until numWorkers).map { r =>
      val rPrev = (r + numWorkers - 1) % numWorkers
      val rNext = (r + 1) % numWorkers
      sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
    }.toMap
  }

  private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
  private[this] val parentMap_ = (0 until numWorkers).map { r => r -> ((r + 1) / 2 - 1) }.toMap
  private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)
  val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
    case ((rmap, k), i) =>
      val kNext = ringMap_(k)._2
      (rmap ++ Map(kNext -> (i + 1)), kNext)
  }._1

  val ringMap = ringMap_.map {
    case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
  }
  val treeMap = treeMap_.map {
    case (k, vSeq) => rMap_(k) -> vSeq.map { v => rMap_(v) }
  }
  val parentMap = parentMap_.map {
    case (k, v) if k == 0 =>
      rMap_(k) -> -1
    case (k, v) =>
      rMap_(k) -> rMap_(v)
  }

  def assignRank(rank: Int): AssignedRank = {
    AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
  }
}
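Putting the pieces together, the tracker's use of this class reduces to a couple of calls; this mirrors the unit test further below, and the concrete neighbor/ring values depend on numWorkers:

val linkMap = new LinkMap(4)           // one LinkMap per tracked job
val assigned = linkMap.assignRank(0)   // tree/ring/parent info for the worker of rank 0
val frame = assigned.toByteBuffer(4)   // worldSize must equal numWorkers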
@@ -1,39 +0,0 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit.util

import java.nio.{ByteBuffer, ByteOrder}
import akka.util.ByteString

private[rabit] object RabitTrackerHelpers {
  implicit class ByteStringHelpers(bs: ByteString) {
    // Java defaults to big-endian. Enforce native endianness so that
    // the byte order is consistent with the workers.
    def asNativeOrderByteBuffer: ByteBuffer = {
      bs.asByteBuffer.order(ByteOrder.nativeOrder())
    }
  }

  implicit class ByteBufferHelpers(buf: ByteBuffer) {
    def getString: String = {
      val len = buf.getInt()
      val stringBuffer = ByteBuffer.allocate(len).order(ByteOrder.nativeOrder())
      buf.get(stringBuffer.array(), 0, len)
      new String(stringBuffer.array(), "utf-8")
    }
  }
}
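A hedged usage sketch of these implicits when decoding a received frame (`handleFrame` is a hypothetical method, not part of the tracker):

import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.util.RabitTrackerHelpers._

// Hypothetical: decode a rank plus a length-prefixed string from a frame.
def handleFrame(data: ByteString): (Int, String) = {
  val buf = data.asNativeOrderByteBuffer   // ByteString -> native-order ByteBuffer
  (buf.getInt, buf.getString)
}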

@@ -30,8 +30,8 @@ import org.junit.Test;
 * @author hzx
 */
public class BoosterImplTest {
  private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1";
  private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1";
  private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
  private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1&format=libsvm";

  public static class EvalError implements IEvaluation {
    @Override

@@ -4,7 +4,7 @@
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software

@@ -88,7 +88,7 @@ public class DMatrixTest {
  public void testCreateFromFile() throws XGBoostError {
    //create DMatrix from file
    String filePath = writeResourceIntoTempFile("/agaricus.txt.test");
    DMatrix dmat = new DMatrix(filePath);
    DMatrix dmat = new DMatrix(filePath + "?format=libsvm");
    //get label
    float[] labels = dmat.getLabel();
    //check length
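These test changes reflect that newer XGBoost no longer guesses the text format of a file URI; the `?format=libsvm` query suffix states it explicitly. A minimal sketch of the updated pattern, using the demo data paths seen throughout these tests:

// The text format must now be declared in the URI when loading a file.
val dmat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")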
@@ -20,12 +20,12 @@ import java.util.Arrays

import scala.util.Random

import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}

class DMatrixSuite extends FunSuite {
class DMatrixSuite extends AnyFunSuite {
  test("create DMatrix from File") {
    val dmat = new DMatrix("../../demo/data/agaricus.txt.test")
    val dmat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
    // get label
    val labels: Array[Float] = dmat.getLabel
    // check length
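The FunSuite change reflects ScalaTest 3.x's reorganization: each testing style moved into its own package, so `org.scalatest.FunSuite` becomes `org.scalatest.funsuite.AnyFunSuite`. A minimal sketch of a migrated suite (the suite and test names here are hypothetical):

import org.scalatest.funsuite.AnyFunSuite

class ExampleSuite extends AnyFunSuite {
  test("addition") {
    assert(1 + 1 === 2)
  }
}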
@@ -20,11 +20,11 @@ import java.io.{FileOutputStream, FileInputStream, File}

import junit.framework.TestCase
import org.apache.commons.logging.LogFactory
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite

import ml.dmlc.xgboost4j.java.XGBoostError

class ScalaBoosterImplSuite extends FunSuite {
class ScalaBoosterImplSuite extends AnyFunSuite {

  private class EvalError extends EvalTrait {

@@ -95,8 +95,8 @@ class ScalaBoosterImplSuite extends FunSuite {
  }

  test("basic operation of booster") {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")

    val booster = trainBooster(trainMat, testMat)
    val predicts = booster.predict(testMat, true)

@@ -106,8 +106,8 @@ class ScalaBoosterImplSuite extends FunSuite {

  test("save/load model with path") {

    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
    val eval = new EvalError
    val booster = trainBooster(trainMat, testMat)
    // save and load

@@ -123,8 +123,8 @@ class ScalaBoosterImplSuite extends FunSuite {
  }

  test("save/load model with stream") {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
    val eval = new EvalError
    val booster = trainBooster(trainMat, testMat)
    // save and load

@@ -139,7 +139,7 @@ class ScalaBoosterImplSuite extends FunSuite {
  }

  test("cross validation") {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
    val params = List("eta" -> "1.0", "max_depth" -> "3", "silent" -> "1", "nthread" -> "6",
      "objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
    val round = 2

@@ -148,8 +148,8 @@ class ScalaBoosterImplSuite extends FunSuite {
  }

  test("test with quantile histo depthwise") {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
    val paramMap = List("max_depth" -> "3", "silent" -> "0",
      "objective" -> "binary:logistic", "tree_method" -> "hist",
      "grow_policy" -> "depthwise", "eval_metric" -> "auc").toMap

@@ -158,8 +158,8 @@ class ScalaBoosterImplSuite extends FunSuite {
  }

  test("test with quantile histo lossguide") {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
    val paramMap = List("max_depth" -> "3", "silent" -> "0",
      "objective" -> "binary:logistic", "tree_method" -> "hist",
      "grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "auc").toMap

@@ -168,8 +168,8 @@ class ScalaBoosterImplSuite extends FunSuite {
  }

  test("test with quantile histo lossguide with max bin") {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
    val paramMap = List("max_depth" -> "3", "silent" -> "0",
      "objective" -> "binary:logistic", "tree_method" -> "hist",
      "grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",

@@ -179,8 +179,8 @@ class ScalaBoosterImplSuite extends FunSuite {
  }

  test("test with quantile histo depthwidth with max depth") {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
    val paramMap = List("max_depth" -> "0", "silent" -> "0",
      "objective" -> "binary:logistic", "tree_method" -> "hist",
      "grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",

@@ -190,8 +190,8 @@ class ScalaBoosterImplSuite extends FunSuite {
  }

  test("test with quantile histo depthwidth with max depth and max bin") {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
    val paramMap = List("max_depth" -> "0", "silent" -> "0",
      "objective" -> "binary:logistic", "tree_method" -> "hist",
      "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",

@@ -201,7 +201,7 @@ class ScalaBoosterImplSuite extends FunSuite {
  }

  test("test training from existing model in scala") {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
    val paramMap = List("max_depth" -> "0", "silent" -> "0",
      "objective" -> "binary:logistic", "tree_method" -> "hist",
      "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",

@@ -213,8 +213,8 @@ class ScalaBoosterImplSuite extends FunSuite {
  }

  test("test getting number of features from a booster") {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
    val booster = trainBooster(trainMat, testMat)

    TestCase.assertEquals(booster.getNumFeature, 127)
@@ -1,255 +0,0 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit

import java.nio.{ByteBuffer, ByteOrder}

import akka.actor.{ActorRef, ActorSystem}
import akka.io.Tcp
import akka.testkit.{ImplicitSender, TestFSMRef, TestKit, TestProbe}
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler._
import ml.dmlc.xgboost4j.scala.rabit.util.LinkMap
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpecLike, Matchers}

import scala.concurrent.Promise

object RabitTrackerConnectionHandlerTest {
  def intSeqToByteString(seq: Seq[Int]): ByteString = {
    val buf = ByteBuffer.allocate(seq.length * 4).order(ByteOrder.nativeOrder())
    seq.foreach { i => buf.putInt(i) }
    buf.flip()
    ByteString.fromByteBuffer(buf)
  }
}

@RunWith(classOf[JUnitRunner])
class RabitTrackerConnectionHandlerTest
  extends TestKit(ActorSystem("RabitTrackerConnectionHandlerTest"))
  with FlatSpecLike with Matchers with ImplicitSender {

  import RabitTrackerConnectionHandlerTest._

  val magic = intSeqToByteString(List(0xff99))

  "RabitTrackerConnectionHandler" should "handle Rabit client 'start' command properly" in {
    val trackerProbe = TestProbe()
    val connProbe = TestProbe()

    val worldSize = 4

    val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
      trackerProbe.ref, connProbe.ref))
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake

    // send mock magic number
    fsm ! Tcp.Received(magic)
    connProbe.expectMsg(Tcp.Write(magic))

    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
    fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
    // ResumeReading should be seen once state transitions
    connProbe.expectMsg(Tcp.ResumeReading)

    // send mock tracker command in fragments: the handler should be able to handle it.
    val bufRank = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
    bufRank.putInt(0).putInt(worldSize).flip()

    val bufJobId = ByteBuffer.allocate(5).order(ByteOrder.nativeOrder())
    bufJobId.putInt(1).put(Array[Byte]('0')).flip()

    val bufCmd = ByteBuffer.allocate(9).order(ByteOrder.nativeOrder())
    bufCmd.putInt(5).put("start".getBytes()).flip()

    fsm ! Tcp.Received(ByteString.fromByteBuffer(bufRank))
    fsm ! Tcp.Received(ByteString.fromByteBuffer(bufJobId))

    // the state should not change for incomplete command data.
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand

    // send the last fragment, and expect a message at the tracker actor.
    fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
    trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))

    val linkMap = new LinkMap(worldSize)
    val assignedRank = linkMap.assignRank(0)
    trackerProbe.reply(assignedRank)

    connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
      assignedRank.toByteBuffer(worldSize)
    )))

    // reading should be suspended upon transitioning to BuildingLinkMap
    connProbe.expectMsg(Tcp.SuspendReading)
    // state should transition with the corresponding state-data changes.
    fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
    fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
    connProbe.expectMsg(Tcp.ResumeReading)

    // since the connection handler in test has rank 0, it will not have any nodes to connect to.
    fsm ! Tcp.Received(intSeqToByteString(List(0)))
    trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))

    // return mock response to the connection handler
    val awaitConnPromise = Promise[AwaitingConnections]()
    awaitConnPromise.success(AwaitingConnections(Map.empty[Int, ActorRef],
      fsm.underlyingActor.getNeighboringWorkers.size
    ))
    fsm ! awaitConnPromise.future
    connProbe.expectMsg(Tcp.Write(
      intSeqToByteString(List(0, fsm.underlyingActor.getNeighboringWorkers.size))
    ))
    connProbe.expectMsg(Tcp.SuspendReading)
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingErrorCount
    connProbe.expectMsg(Tcp.ResumeReading)

    // send mock error count (0)
    fsm ! Tcp.Received(intSeqToByteString(List(0)))

    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
    connProbe.expectMsg(Tcp.ResumeReading)

    // simulate Tcp.PeerClosed event first, then Tcp.Received, to test handling of async events.
    fsm ! Tcp.PeerClosed
    // state should not transition
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
    fsm ! Tcp.Received(intSeqToByteString(List(32768)))

    fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
    connProbe.expectMsg(Tcp.ResumeReading)

    trackerProbe.expectMsg(RabitWorkerHandler.WorkerStarted("localhost", 0, 2))

    val handlerStopProbe = TestProbe()
    handlerStopProbe watch fsm

    // simulate connections from other workers by mocking ReduceWaitCount commands
    fsm ! RabitWorkerHandler.ReduceWaitCount(1)
    fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
    fsm ! RabitWorkerHandler.ReduceWaitCount(1)
    trackerProbe.expectMsg(RabitWorkerHandler.DropFromWaitingList(0))
    handlerStopProbe.expectTerminated(fsm)

    // all done.
  }

  it should "forward print command to tracker" in {
    val trackerProbe = TestProbe()
    val connProbe = TestProbe()

    val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
      trackerProbe.ref, connProbe.ref))
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake

    fsm ! Tcp.Received(magic)
    connProbe.expectMsg(Tcp.Write(magic))

    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
    fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
    // ResumeReading should be seen once state transitions
    connProbe.expectMsg(Tcp.ResumeReading)

    val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!")
    fsm ! Tcp.Received(printCmd.encode)

    trackerProbe.expectMsg(printCmd)
  }

  it should "handle fragmented print command without throwing exception" in {
    val trackerProbe = TestProbe()
    val connProbe = TestProbe()

    val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
      trackerProbe.ref, connProbe.ref))
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake

    fsm ! Tcp.Received(magic)
    connProbe.expectMsg(Tcp.Write(magic))

    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
    fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
    // ResumeReading should be seen once state transitions
    connProbe.expectMsg(Tcp.ResumeReading)

    val printCmd = WorkerTrackerPrint(0, 4, "0", "fragmented!")
    // 4 (rank: Int) + 4 (worldSize: Int) + (4+1) (jobId: String) + (4+5) (command: String) = 22
    val (partialMessage, remainder) = printCmd.encode.splitAt(22)

    // make sure that the partialMessage in itself is a valid command
    val partialMsgBuf = ByteBuffer.allocate(22).order(ByteOrder.nativeOrder())
    partialMsgBuf.put(partialMessage.asByteBuffer)
    RabitWorkerHandler.StructTrackerCommand.verify(partialMsgBuf) shouldBe true

    fsm ! Tcp.Received(partialMessage)
    fsm ! Tcp.Received(remainder)

    trackerProbe.expectMsg(printCmd)
  }

  it should "handle spill-over Tcp data correctly between state transitions" in {
    val trackerProbe = TestProbe()
    val connProbe = TestProbe()

    val worldSize = 4

    val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
      trackerProbe.ref, connProbe.ref))
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake

    // send mock magic number
    fsm ! Tcp.Received(magic)
    connProbe.expectMsg(Tcp.Write(magic))

    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
    fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
    // ResumeReading should be seen once state transitions
    connProbe.expectMsg(Tcp.ResumeReading)

    // send a mock tracker command with trailing spill-over data: the handler should cope.
    val bufCmd = ByteBuffer.allocate(26).order(ByteOrder.nativeOrder())
    bufCmd.putInt(0).putInt(worldSize).putInt(1).put(Array[Byte]('0'))
      .putInt(5).put("start".getBytes())
      // spilled-over data
      .putInt(0).flip()

    // send data with 4 extra bytes corresponding to the next state.
    fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))

    trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))

    val linkMap = new LinkMap(worldSize)
    val assignedRank = linkMap.assignRank(0)
    trackerProbe.reply(assignedRank)

    connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
      assignedRank.toByteBuffer(worldSize)
    )))

    // reading should be suspended upon transitioning to BuildingLinkMap
    connProbe.expectMsg(Tcp.SuspendReading)
    // state should transition with the corresponding state-data changes.
    fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
    fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
    connProbe.expectMsg(Tcp.ResumeReading)

    // the handler should be able to handle spill-over data, and stash it until state transition.
    trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
  }
}