merge 23Mar01

This commit is contained in:
amdsc21
2023-05-02 00:05:58 +02:00
258 changed files with 7471 additions and 5379 deletions

View File

@@ -33,16 +33,16 @@
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<flink.version>1.8.3</flink.version>
<spark.version>3.1.1</spark.version>
<scala.version>2.12.8</scala.version>
<flink.version>1.17.0</flink.version>
<spark.version>3.4.0</spark.version>
<scala.version>2.12.17</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<hadoop.version>3.3.5</hadoop.version>
<maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
<log.capi.invocation>OFF</log.capi.invocation>
<use.cuda>OFF</use.cuda>
<cudf.version>22.12.0</cudf.version>
<spark.rapids.version>22.12.0</spark.rapids.version>
<cudf.version>23.04.0</cudf.version>
<spark.rapids.version>23.04.0</spark.rapids.version>
<cudf.classifier>cuda11</cudf.classifier>
</properties>
<repositories>
@@ -374,7 +374,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<version>3.2.1</version>
<version>3.2.2</version>
<configuration>
<configLocation>checkstyle.xml</configLocation>
<failOnViolation>true</failOnViolation>
@@ -450,7 +450,7 @@
<plugins>
<plugin>
<artifactId>maven-project-info-reports-plugin</artifactId>
<version>3.4.2</version>
<version>3.4.3</version>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
@@ -469,7 +469,7 @@
<dependency>
<groupId>com.esotericsoftware</groupId>
<artifactId>kryo</artifactId>
<version>5.4.0</version>
<version>5.5.0</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
@@ -477,11 +477,6 @@
<version>${scala.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
@@ -495,13 +490,13 @@
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>3.0.8</version>
<version>3.2.15</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalactic</groupId>
<artifactId>scalactic_${scala.binary.version}</artifactId>
<version>3.0.8</version>
<version>3.2.15</version>
<scope>test</scope>
</dependency>
</dependencies>

View File

@@ -26,7 +26,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
@@ -37,12 +37,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
<version>${project.version}</version>
</dependency>
</dependencies>
</project>

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2021 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -62,8 +62,8 @@ public class BasicWalkThrough {
public static void main(String[] args) throws IOException, XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
@@ -112,7 +112,8 @@ public class BasicWalkThrough {
System.out.println("start build dmatrix from csr sparse data ...");
//build dmatrix from CSR Sparse Matrix
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
DataLoader.CSRSparseData spData =
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
DMatrix.SparseType.CSR, 127);

View File

@@ -32,8 +32,8 @@ public class BoostFromPrediction {
System.out.println("start running example to start from a initial prediction");
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@@ -30,7 +30,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
public class CrossValidation {
public static void main(String[] args) throws IOException, XGBoostError {
//load train mat
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
//set params
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@@ -139,9 +139,9 @@ public class CustomObjective {
public static void main(String[] args) throws XGBoostError {
//load train mat (svmlight format)
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
//load valid mat (svmlight format)
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);

View File

@@ -29,9 +29,9 @@ import ml.dmlc.xgboost4j.java.example.util.DataLoader;
public class EarlyStopping {
public static void main(String[] args) throws IOException, XGBoostError {
DataLoader.CSRSparseData trainCSR =
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
DataLoader.CSRSparseData testCSR =
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test");
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test?format=libsvm");
Map<String, Object> paramMap = new HashMap<String, Object>() {
{

View File

@@ -32,8 +32,8 @@ public class ExternalMemory {
//this is the only difference, add a # followed by a cache prefix name
//several cache files with the prefix will be generated
//currently only support convert from libsvm file
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@@ -32,8 +32,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
public class GeneralizedLinearModel {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
//change booster to gblinear, so that we are fitting a linear model

View File

@@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
public class PredictFirstNtree {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
public class PredictLeafIndices {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@@ -0,0 +1,107 @@
/*
Copyright (c) 2014-2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example.flink;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple13;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import ml.dmlc.xgboost4j.java.flink.XGBoost;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
/**
 * Example: distributed training and prediction with XGBoost on the Flink
 * DataSet API, driven by the veterans' lung cancer CSV dataset.
 */
public class DistTrainWithFlinkExample {
/**
 * Trains a model on the first {@code percentage} percent of the rows and
 * predicts on the remaining rows.
 *
 * @param env Flink execution environment
 * @param trainPath CSV file with a header row and 13 columns
 * @param percentage share of rows (0-100) used for the training split
 * @return the trained model together with predictions for the test rows
 * @throws Exception if reading, training, or prediction fails
 */
static Tuple2<XGBoostModel, DataSet<Float[]>> runPrediction(
ExecutionEnvironment env,
java.nio.file.Path trainPath,
int percentage) throws Exception {
// reading data; zipWithIndex assigns each row a stable index used for splitting
final DataSet<Tuple2<Long, Tuple2<Vector, Double>>> data =
DataSetUtils.zipWithIndex(parseCsv(env, trainPath));
final long size = data.count();
// number of leading rows that form the training split
final long trainCount = Math.round(size * 0.01 * percentage);
final DataSet<Tuple2<Vector, Double>> trainData =
data
.filter(item -> item.f0 < trainCount)
.map(t -> t.f1)
.returns(TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>(){}));
// remaining rows, features only, become the test split
final DataSet<Vector> testData =
data
.filter(tuple -> tuple.f0 >= trainCount)
.map(t -> t.f1.f0)
.returns(TypeInformation.of(new TypeHint<Vector>(){}));
// define parameters
HashMap<String, Object> paramMap = new HashMap<String, Object>(3);
paramMap.put("eta", 0.1);
paramMap.put("max_depth", 2);
paramMap.put("objective", "binary:logistic");
// number of iterations
final int round = 2;
// train the model
XGBoostModel model = XGBoost.train(trainData, paramMap, round);
DataSet<Float[]> predTest = model.predict(testData);
return new Tuple2<XGBoostModel, DataSet<Float[]>>(model, predTest);
}
// Reads the headered 13-column CSV (double, string, 3 doubles, 8 integers)
// and maps each row to a (features, label) pair.
private static MapOperator<Tuple13<Double, String, Double, Double, Double, Integer, Integer,
Integer, Integer, Integer, Integer, Integer, Integer>,
Tuple2<Vector, Double>> parseCsv(ExecutionEnvironment env, Path trainPath) {
return env.readCsvFile(trainPath.toString())
.ignoreFirstLine()
.types(Double.class, String.class, Double.class, Double.class, Double.class,
Integer.class, Integer.class, Integer.class, Integer.class, Integer.class,
Integer.class, Integer.class, Integer.class)
.map(DistTrainWithFlinkExample::mapFunction);
}
// Builds the dense feature vector from columns f2..f12. The label is 1.0
// when string column f1 contains "inf" — presumably the status/censoring
// field of the dataset (TODO confirm against the CSV header) — else 0.0.
private static Tuple2<Vector, Double> mapFunction(Tuple13<Double, String, Double, Double, Double,
Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer> tuple) {
final DenseVector dense = Vectors.dense(tuple.f2, tuple.f3, tuple.f4, tuple.f5, tuple.f6,
tuple.f7, tuple.f8, tuple.f9, tuple.f10, tuple.f11, tuple.f12);
if (tuple.f1.contains("inf")) {
return new Tuple2<Vector, Double>(dense, 1.0);
} else {
return new Tuple2<Vector, Double>(dense, 0.0);
}
}
/**
 * Entry point: args[0] (default ".") is the directory containing
 * veterans_lung_cancer.csv; trains on 70% of the rows and prints the
 * number of predictions produced for the rest.
 */
public static void main(String[] args) throws Exception {
final java.nio.file.Path parentPath = java.nio.file.Paths.get(Arrays.stream(args)
.findFirst().orElse("."));
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
Tuple2<XGBoostModel, DataSet<Float[]>> tuple2 = runPrediction(
env, parentPath.resolve("veterans_lung_cancer.csv"), 70
);
List<Float[]> list = tuple2.f1.collect();
System.out.println(list.size());
}
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -36,8 +36,8 @@ object BasicWalkThrough {
}
def main(args: Array[String]): Unit = {
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMax = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
@@ -76,7 +76,7 @@ object BasicWalkThrough {
// build dmatrix from CSR Sparse Matrix
println("start build dmatrix from csr sparse data ...")
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train")
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm")
val trainMax2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
JDMatrix.SparseType.CSR)
trainMax2.setLabel(spData.labels)

View File

@@ -24,8 +24,8 @@ object BoostFromPrediction {
def main(args: Array[String]): Unit = {
println("start running example to start from a initial prediction")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object CrossValidation {
def main(args: Array[String]): Unit = {
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train")
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
// set params
val params = new mutable.HashMap[String, Any]

View File

@@ -138,8 +138,8 @@ object CustomObjective {
}
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
params += "max_depth" -> 2

View File

@@ -25,8 +25,8 @@ object ExternalMemory {
// this is the only difference, add a # followed by a cache prefix name
// several cache files with the prefix will be generated
// currently only support convert from libsvm file
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@@ -27,8 +27,8 @@ import ml.dmlc.xgboost4j.scala.example.util.CustomEval
*/
object GeneralizedLinearModel {
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
// specify parameters
// change booster to gblinear, so that we are fitting a linear model

View File

@@ -23,8 +23,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object PredictFirstNTree {
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@@ -25,8 +25,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object PredictLeafIndices {
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,27 +15,84 @@
*/
package ml.dmlc.xgboost4j.scala.example.flink
import ml.dmlc.xgboost4j.scala.flink.XGBoost
import org.apache.flink.api.scala.{ExecutionEnvironment, _}
import org.apache.flink.ml.MLUtils
import java.lang.{Double => JDouble, Long => JLong}
import java.nio.file.{Path, Paths}
import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
import org.apache.flink.ml.linalg.{Vector, Vectors}
import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
import org.apache.flink.api.java.utils.DataSetUtils
object DistTrainWithFlink {
def main(args: Array[String]) {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
// read training data
val trainData =
MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test")
// define parameters
val paramMap = List(
"eta" -> 0.1,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
import scala.jdk.CollectionConverters._
private val rowTypeHint = TypeInformation.of(new TypeHint[Tuple2[Vector, JDouble]]{})
private val testDataTypeHint = TypeInformation.of(classOf[Vector])
private[flink] def parseCsv(trainPath: Path)(implicit env: ExecutionEnvironment):
DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = {
DataSetUtils.zipWithIndex(
env
.readCsvFile(trainPath.toString)
.ignoreFirstLine
.types(
classOf[Double], classOf[String], classOf[Double], classOf[Double], classOf[Double],
classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer],
classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer]
)
.map((row: Tuple13[Double, String, Double, Double, Double,
Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer]) => {
val dense = Vectors.dense(row.f2, row.f3, row.f4,
row.f5.toDouble, row.f6.toDouble, row.f7.toDouble, row.f8.toDouble,
row.f9.toDouble, row.f10.toDouble, row.f11.toDouble, row.f12.toDouble)
val label = if (row.f1.contains("inf")) {
JDouble.valueOf(1.0)
} else {
JDouble.valueOf(0.0)
}
new Tuple2[Vector, JDouble](dense, label)
})
.returns(rowTypeHint)
)
}
private[flink] def runPrediction(trainPath: Path, percentage: Int)
(implicit env: ExecutionEnvironment):
(XGBoostModel, DataSet[Array[Float]]) = {
// read training data
val data: DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = parseCsv(trainPath)
val trainSize = Math.round(0.01 * percentage * data.count())
val trainData: DataSet[Tuple2[Vector, JDouble]] =
data.filter(d => d.f0 < trainSize).map(_.f1).returns(rowTypeHint)
val testData: DataSet[Vector] =
data
.filter(d => d.f0 >= trainSize)
.map(_.f1.f0)
.returns(testDataTypeHint)
val paramMap = mapAsJavaMap(Map(
("eta", "0.1".asInstanceOf[AnyRef]),
("max_depth", "2"),
("objective", "binary:logistic"),
("verbosity", "1")
))
// number of iterations
val round = 2
// train the model
val model = XGBoost.train(trainData, paramMap, round)
val predTest = model.predict(testData.map{x => x.vector})
model.saveModelAsHadoopFile("file:///path/to/xgboost.model")
val result = model.predict(testData).map(prediction => prediction.map(Float.unbox))
(model, result)
}
def main(args: Array[String]): Unit = {
implicit val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val parentPath = Paths.get(args.headOption.getOrElse("."))
val (_, predTest) = runPrediction(parentPath.resolve("veterans_lung_cancer.csv"), 70)
val list = predTest.collect().asScala
println(list.length)
}
}

View File

@@ -0,0 +1,36 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example.flink
import org.apache.flink.api.java.ExecutionEnvironment
import org.scalatest.Inspectors._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._
import java.nio.file.Paths
/**
 * Smoke test for the Java Flink example
 * ([[ml.dmlc.xgboost4j.java.example.flink.DistTrainWithFlinkExample]]).
 */
class DistTrainWithFlinkExampleTest extends AnyFunSuite {
  // demo dataset shipped with the repository, relative to the module directory
  private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
  private val data = parentPath.resolve("veterans_lung_cancer.csv")
  // Fixed: the description previously said "scala flink example", but this
  // suite exercises the Java example class.
  test("Smoke test for java flink example") {
    // single-slot local environment keeps the smoke test deterministic and light
    val env = ExecutionEnvironment.createLocalEnvironment(1)
    val tuple2 = DistTrainWithFlinkExample.runPrediction(env, data, 70)
    val results = tuple2.f1.collect()
    // 30% of the 137-row dataset -> 41 test rows, one prediction each
    results should have size 41
    forEvery(results)(item => item should have size 1)
  }
}

View File

@@ -0,0 +1,37 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example.flink
import org.apache.flink.api.java.ExecutionEnvironment
import org.scalatest.Inspectors._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._
import java.nio.file.Paths
import scala.jdk.CollectionConverters._
/**
 * Smoke test for the Scala Flink example ([[DistTrainWithFlink]]).
 */
class DistTrainWithFlinkSuite extends AnyFunSuite {
  // demo dataset shipped with the repository, relative to the module directory
  private val dataFile =
    Paths.get("../../").resolve("demo").resolve("data").resolve("veterans_lung_cancer.csv")
  test("Smoke test for scala flink example") {
    // single-slot local environment keeps the run deterministic and lightweight
    implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1)
    val (_, predictions) = DistTrainWithFlink.runPrediction(dataFile, 70)
    val collected = predictions.collect().asScala
    collected should have size 41
    forEvery(collected)(row => row should have size 1)
  }
}

View File

@@ -8,8 +8,11 @@
<artifactId>xgboost-jvm_2.12</artifactId>
<version>2.0.0-SNAPSHOT</version>
</parent>
<artifactId>xgboost4j-flink_2.12</artifactId>
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
<properties>
<flink-ml.version>2.2.0</flink-ml.version>
</properties>
<build>
<plugins>
<plugin>
@@ -26,32 +29,22 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-scala_${scala.binary.version}</artifactId>
<artifactId>flink-clients</artifactId>
<version>${flink.version}</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-clients_${scala.binary.version}</artifactId>
<version>${flink.version}</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-ml_${scala.binary.version}</artifactId>
<version>${flink.version}</version>
<artifactId>flink-ml-servable-core</artifactId>
<version>${flink-ml.version}</version>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>3.3.5</version>
<version>${hadoop.version}</version>
</dependency>
</dependencies>

View File

@@ -0,0 +1,187 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.flink;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.RabitTracker;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
 * Distributed training of XGBoost models on Flink {@link DataSet}s.
 *
 * <p>Training runs once per Flink partition; the workers are coordinated by
 * the Rabit tracker started in {@link #train} and the {@link Communicator}
 * session opened inside each task.
 */
public class XGBoost {
private static final Logger logger = LoggerFactory.getLogger(XGBoost.class);
/**
 * Per-partition training task: converts the partition's rows into a
 * {@link DMatrix}, joins the Rabit cluster, and emits the trained model.
 */
private static class MapFunction
extends RichMapPartitionFunction<Tuple2<Vector, Double>, XGBoostModel> {
// booster parameters, passed verbatim to xgboost4j
private final Map<String, Object> params;
// number of boosting rounds
private final int round;
// environment handed out by the Rabit tracker; shared by all subtasks
private final Map<String, String> workerEnvs;
public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
this.params = params;
this.round = round;
this.workerEnvs = workerEnvs;
}
// Trains on this partition's rows and emits one XGBoostModel;
// an empty partition emits nothing (only a warning is logged).
public void mapPartition(java.lang.Iterable<Tuple2<Vector, Double>> it,
Collector<XGBoostModel> collector) throws XGBoostError {
// each Rabit worker must report a distinct task id
workerEnvs.put(
"DMLC_TASK_ID",
String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask())
);
if (logger.isInfoEnabled()) {
logger.info("start with env: {}", workerEnvs.entrySet().stream()
.map(e -> String.format("\"%s\": \"%s\"", e.getKey(), e.getValue()))
.collect(Collectors.joining(", "))
);
}
// lazily convert (vector, label) tuples into xgboost LabeledPoints
final Iterator<LabeledPoint> dataIter =
StreamSupport
.stream(it.spliterator(), false)
.map(VectorToPointMapper.INSTANCE)
.iterator();
if (dataIter.hasNext()) {
final DMatrix trainMat = new DMatrix(dataIter, null);
// optional early stopping, taken from the params map (0 = disabled)
int numEarlyStoppingRounds =
Optional.ofNullable(params.get("numEarlyStoppingRounds"))
.map(x -> Integer.parseInt(x.toString()))
.orElse(0);
final Booster booster = trainBooster(trainMat, numEarlyStoppingRounds);
collector.collect(new XGBoostModel(booster));
} else {
logger.warn("Nothing to train with.");
}
}
// Runs the actual training inside a Communicator (Rabit) session; the
// session is always shut down, even when training throws.
private Booster trainBooster(DMatrix trainMat,
int numEarlyStoppingRounds) throws XGBoostError {
Booster booster;
final Map<String, DMatrix> watches =
new HashMap<String, DMatrix>() {{ put("train", trainMat); }};
try {
Communicator.init(workerEnvs);
booster = ml.dmlc.xgboost4j.java.XGBoost
.train(
trainMat,
params,
round,
watches,
null,
null,
null,
numEarlyStoppingRounds);
} catch (XGBoostError xgbException) {
final String identifier = String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask());
logger.warn(
String.format("XGBooster worker %s has failed due to", identifier),
xgbException
);
throw xgbException;
} finally {
Communicator.shutdown();
}
return booster;
}
// Converts a Flink (vector, label) tuple into an xgboost LabeledPoint,
// narrowing the double feature values to float.
private static class VectorToPointMapper
implements Function<Tuple2<Vector, Double>, LabeledPoint> {
public static VectorToPointMapper INSTANCE = new VectorToPointMapper();
@Override
public LabeledPoint apply(Tuple2<Vector, Double> tuple) {
final SparseVector vector = tuple.f0.toSparse();
final double[] values = vector.values;
final int size = values.length;
final float[] array = new float[size];
for (int i = 0; i < size; i++) {
array[i] = (float) values[i];
}
return new LabeledPoint(
tuple.f1.floatValue(),
vector.size(),
vector.indices,
array);
}
}
}
/**
 * Load XGBoost model from path, using Hadoop Filesystem API.
 *
 * @param modelPath The path that is accessible by hadoop filesystem API.
 * @return The loaded model
 * @throws Exception if the file cannot be opened or the model cannot be parsed
 */
public static XGBoostModel loadModelFromHadoopFile(final String modelPath) throws Exception {
final FileSystem fileSystem = FileSystem.get(new Configuration());
final Path f = new Path(modelPath);
try (FSDataInputStream opened = fileSystem.open(f)) {
return new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(opened));
}
}
/**
 * Train an XGBoost model with Flink.
 *
 * @param dtrain The training data.
 * @param params XGBoost parameters.
 * @param numBoostRound Number of rounds to train.
 * @return the model produced by one of the workers
 * @throws Exception if the tracker cannot be started or training fails
 */
public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain,
Map<String, Object> params,
int numBoostRound) throws Exception {
// one tracker slot per Flink subtask
final RabitTracker tracker =
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
if (tracker.start(0L)) {
return dtrain
// NOTE(review): reduce keeps an arbitrary worker's model — presumably
// all workers hold equivalent boosters after allreduce; confirm.
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
.reduce((x, y) -> x)
.collect()
.get(0);
} else {
throw new Error("Tracker cannot be started");
}
}
}

View File

@@ -0,0 +1,136 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.flink;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.stream.StreamSupport;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
 * Serializable Flink wrapper around a trained XGBoost {@link Booster}.
 *
 * <p>Supports saving the model through the Hadoop filesystem API and scoring
 * both a local {@link DMatrix} and a distributed Flink {@link DataSet} of vectors.
 */
public class XGBoostModel implements Serializable {
  private static final org.slf4j.Logger logger =
      org.slf4j.LoggerFactory.getLogger(XGBoostModel.class);

  // Trained native booster backing this model.
  private final Booster booster;
  // Reusable partition-level scoring function for predict(DataSet).
  private final PredictorFunction predictorFunction;

  public XGBoostModel(Booster booster) {
    this.booster = booster;
    this.predictorFunction = new PredictorFunction(booster);
  }

  /**
   * Save the model as a Hadoop filesystem file.
   *
   * @param modelPath The model path as in Hadoop path.
   * @throws IOException if the output stream cannot be created
   * @throws XGBoostError if the native save fails
   */
  public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError {
    final Configuration hadoopConf = new Configuration();
    booster.saveModel(FileSystem.get(hadoopConf).create(new Path(modelPath)));
  }

  /**
   * Serialize the model to a byte array in the given format.
   *
   * @param format The model format (ubj, json, deprecated)
   * @return the serialized model bytes
   * @throws XGBoostError internal error
   */
  public byte[] toByteArray(String format) throws XGBoostError {
    return booster.toByteArray(format);
  }

  /**
   * Save the model as a Hadoop filesystem file in the given format.
   *
   * @param modelPath The model path as in Hadoop path.
   * @param format The model format (ubj, json, deprecated)
   * @throws XGBoostError internal error
   * @throws IOException save error
   */
  public void saveModelAsHadoopFile(String modelPath, String format)
      throws IOException, XGBoostError {
    final Configuration hadoopConf = new Configuration();
    booster.saveModel(FileSystem.get(hadoopConf).create(new Path(modelPath)), format);
  }

  /**
   * Predict with the given DMatrix.
   *
   * @param testSet the local test set represented as DMatrix
   * @return prediction result
   * @throws XGBoostError internal error
   */
  public float[][] predict(DMatrix testSet) throws XGBoostError {
    return booster.predict(testSet, true, 0);
  }

  /**
   * Predict given vector dataset.
   *
   * @param data The dataset to be predicted.
   * @return The prediction result.
   */
  public DataSet<Float[]> predict(DataSet<Vector> data) {
    return data.mapPartition(predictorFunction);
  }

  /** Scores one partition of vectors by batching it into a single DMatrix. */
  private static class PredictorFunction implements MapPartitionFunction<Vector, Float[]> {
    private final Booster booster;

    public PredictorFunction(Booster booster) {
      this.booster = booster;
    }

    @Override
    public void mapPartition(Iterable<Vector> it, Collector<Float[]> out) throws Exception {
      final Iterator<Vector> rows = it.iterator();
      if (!rows.hasNext()) {
        logger.debug("Empty partition");
        return;
      }
      // Lazily convert each row to a LabeledPoint as DMatrix consumes the iterator,
      // avoiding an intermediate materialized copy of the partition.
      final Iterator<LabeledPoint> points = new Iterator<LabeledPoint>() {
        @Override
        public boolean hasNext() {
          return rows.hasNext();
        }

        @Override
        public LabeledPoint next() {
          return fromVector(rows.next().toSparse());
        }
      };
      final DMatrix batch = new DMatrix(points, null);
      // NOTE(review): the tree limit is 2 here but 0 in predict(DMatrix) —
      // confirm that restricting DataSet scoring to 2 trees is intentional.
      final float[][] predictions = booster.predict(batch, true, 2);
      for (float[] row : predictions) {
        out.collect(ArrayUtils.toObject(row));
      }
    }

    /** Converts a sparse vector into an unlabeled (label 0.0f) point for scoring. */
    private static LabeledPoint fromVector(SparseVector vector) {
      final double[] raw = vector.values;
      final float[] features = new float[raw.length];
      for (int i = 0; i < raw.length; i++) {
        features[i] = (float) raw[i];
      }
      return new LabeledPoint(0.0f, vector.size(), vector.indices, features);
    }
  }
}

View File

@@ -1,99 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.flink
import scala.collection.JavaConverters.asScalaIteratorConverter
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
import org.apache.commons.logging.LogFactory
import org.apache.flink.api.common.functions.RichMapPartitionFunction
import org.apache.flink.api.scala.{DataSet, _}
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.util.Collector
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
object XGBoost {

  /**
   * Partition-wise training function: each Flink subtask registers with the
   * Rabit tracker, trains a booster on its local partition, and emits the
   * resulting model.
   *
   * @param paramMap XGBoost training parameters.
   * @param round Number of boosting rounds to train.
   * @param workerEnvs Worker environment handed out by the Rabit tracker.
   */
  private class MapFunction(paramMap: Map[String, Any],
                            round: Int,
                            workerEnvs: java.util.Map[String, String])
    extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
    val logger = LogFactory.getLog(this.getClass)

    def mapPartition(it: java.lang.Iterable[LabeledVector],
                     collector: Collector[XGBoostModel]): Unit = {
      // Each subtask needs a unique task id before joining the Rabit ring.
      workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
      logger.info("start with env" + workerEnvs.toString)
      Communicator.init(workerEnvs)
      // Convert Flink's LabeledVector rows to XGBoost LabeledPoints lazily.
      val mapper = (x: LabeledVector) => {
        val (index, value) = x.vector.toSeq.unzip
        LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
      }
      val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
      val trainMat = new DMatrix(dataIter, null)
      val watches = List("train" -> trainMat).toMap
      val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds")
        .map(_.toString.toInt).getOrElse(0)
      // BUG FIX: a local `val round = 2` used to shadow the constructor
      // parameter, so training always ran exactly 2 rounds regardless of the
      // number of rounds the caller requested.
      val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
        earlyStoppingRound = numEarlyStoppingRounds)
      Communicator.shutdown()
      collector.collect(new XGBoostModel(booster))
    }
  }

  val logger = LogFactory.getLog(this.getClass)

  /**
   * Load XGBoost model from path, using Hadoop Filesystem API.
   *
   * @param modelPath The path that is accessible by hadoop filesystem API.
   * @return The loaded model
   */
  def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
    new XGBoostModel(
      XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath))))
  }

  /**
   * Train an XGBoost model with Flink.
   *
   * @param dtrain The training data.
   * @param params The parameters to XGBoost.
   * @param round Number of rounds to train.
   * @return the trained model.
   */
  def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
      XGBoostModel = {
    val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
    if (tracker.start(0L)) {
      dtrain
        .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
        .reduce((x, y) => x).collect().head
    } else {
      // BUG FIX: removed the unreachable `null` expression that followed this
      // throw (dead code after an expression of type Nothing).
      throw new Error("Tracker cannot be started")
    }
  }
}

View File

@@ -1,67 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.flink
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.flink.api.scala.{DataSet, _}
import org.apache.flink.ml.math.Vector
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
/**
 * Flink-facing wrapper around a trained XGBoost [[Booster]].
 *
 * @param booster the trained native booster backing this model.
 */
class XGBoostModel (booster: Booster) extends Serializable {
  /**
   * Save the model as a Hadoop filesystem file.
   *
   * @param modelPath The model path as in Hadoop path.
   */
  def saveModelAsHadoopFile(modelPath: String): Unit = {
    booster.saveModel(FileSystem
      .get(new Configuration)
      .create(new Path(modelPath)))
  }

  /**
   * predict with the given DMatrix
   * @param testSet the local test set represented as DMatrix
   * @return prediction result
   */
  def predict(testSet: DMatrix): Array[Array[Float]] = {
    // Second arg `true` enables margin output; third arg 0 presumably means
    // "no tree limit" — NOTE(review): confirm against Booster.predict docs.
    booster.predict(testSet, true, 0)
  }

  /**
   * Predict given vector dataset.
   *
   * @param data The dataset to be predicted.
   * @return The prediction result.
   */
  def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
    // Runs once per partition: each vector becomes a LabeledPoint with a dummy
    // 0.0f label (labels are not needed for scoring), the whole partition is
    // batched into one DMatrix, and predictions are emitted per row.
    val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
      (it: Iterator[Vector]) => {
        val mapper = (x: Vector) => {
          val (index, value) = x.toSeq.unzip
          LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
        }
        val dataIter = for (x <- it) yield mapper(x)
        val dmat = new DMatrix(dataIter, null)
        this.booster.predict(dmat)
      }
    data.mapPartition(predictMap)
  }
}

View File

@@ -38,22 +38,10 @@
<version>4.13.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor_${scala.binary.version}</artifactId>
<version>2.6.20</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-testkit_${scala.binary.version}</artifactId>
<version>2.6.20</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>3.0.5</version>
<version>3.2.15</version>
<scope>provided</scope>
</dependency>
<dependency>

View File

@@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala
import scala.collection.mutable.ArrayBuffer
import ai.rapids.cudf.Table
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
class QuantileDMatrixSuite extends FunSuite {
class QuantileDMatrixSuite extends AnyFunSuite {
test("QuantileDMatrix test") {

View File

@@ -44,13 +44,6 @@
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>ai.rapids</groupId>
<artifactId>cudf</artifactId>
<version>${cudf.version}</version>
<classifier>${cudf.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.nvidia</groupId>
<artifactId>rapids-4-spark_${scala.binary.version}</artifactId>

View File

@@ -20,14 +20,15 @@ import java.nio.file.{Files, Path}
import java.sql.{Date, Timestamp}
import java.util.{Locale, TimeZone}
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.{GpuTestUtils, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}
trait GpuTestSuite extends FunSuite with TmpFolderSuite {
trait GpuTestSuite extends AnyFunSuite with TmpFolderSuite {
import SparkSessionHolder.withSparkSession
protected def getResourcePath(resource: String): String = {
@@ -200,7 +201,7 @@ trait GpuTestSuite extends FunSuite with TmpFolderSuite {
}
trait TmpFolderSuite extends BeforeAndAfterAll { self: FunSuite =>
trait TmpFolderSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
protected var tempDir: Path = _
override def beforeAll(): Unit = {

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2021-2022 by Contributors
Copyright (c) 2021-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -22,7 +22,6 @@ import java.util.ServiceLoader
import scala.collection.JavaConverters._
import scala.collection.{AbstractIterator, Iterator, mutable}
import ml.dmlc.xgboost4j.java.Communicator
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
@@ -35,7 +34,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
@@ -263,12 +261,6 @@ object PreXGBoost extends PreXGBoostProvider {
private var batchCnt = 0
private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
if (batchCnt == 0) {
val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Communicator.init(rabitEnv.asJava)
}
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
@@ -295,13 +287,8 @@ object PreXGBoost extends PreXGBoostProvider {
override def hasNext: Boolean = batchIterImpl.hasNext
override def next(): Row = {
val ret = batchIterImpl.next()
if (!batchIterImpl.hasNext) {
Communicator.shutdown()
}
ret
}
override def next(): Row = batchIterImpl.next()
}
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -23,7 +23,6 @@ import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -44,21 +43,16 @@ import org.apache.spark.sql.SparkSession
* Use a finite, non-zero timeout value to prevent tracker from
* hanging indefinitely (in milliseconds)
* (supported by "scala" implementation only.)
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
* This is only needed if the host IP cannot be automatically guessed.
* @param pythonExec The python executed path for Rabit Tracker,
* which is only used for python implementation.
*/
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String,
case class TrackerConf(workerConnectionTimeout: Long,
hostIp: String = "", pythonExec: String = "")
object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L, "python")
def apply(): TrackerConf = TrackerConf(0L)
}
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
@@ -349,11 +343,9 @@ object XGBoost extends Serializable {
/** visiable for testing */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker: IRabitTracker = trackerConf.trackerImpl match {
case "scala" => new RabitTracker(nWorkers)
case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec)
case _ => new PyRabitTracker(nWorkers)
}
val tracker: IRabitTracker = new PyRabitTracker(
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
)
tracker
}

View File

@@ -22,11 +22,10 @@ import scala.util.Random
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
class CommunicatorRobustnessSuite extends FunSuite with PerTest {
class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
val classifier = new XGBoostClassifier(paramMap)
@@ -40,7 +39,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp))
"tracker_conf" -> TrackerConf(0L, hostIp))
val xgbExecParams = getXGBoostExecutionParams(paramMap)
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
tracker match {
@@ -53,7 +52,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap1 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
tracker1 match {
@@ -66,7 +65,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap2 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
tracker2 match {
@@ -78,58 +77,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
}
}
test("training with Scala-implemented Rabit tracker") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("test Communicator allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
val rawData = rdd.mapPartitions { iter =>
Iterator(iter.toArray)
}.collect()
val maxVec = (0 until vectorLength).toArray.map { j =>
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
}
val allReduceResults = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
val arr = iter.toArray
val results = Communicator.allReduce(arr, Communicator.OpType.MAX)
Communicator.shutdown()
Iterator(results)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
allReduceResults.foreachPartition(() => _)
val byPartitionResults = allReduceResults.collect()
assert(byPartitionResults(0).length == vectorLength)
collectedAllReduceResults.put(byPartitionResults(0))
}
}
sparkThread.start()
assert(tracker.waitFor(0L) == 0)
sparkThread.join()
assert(collectedAllReduceResults.poll().sameElements(maxVec))
}
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@@ -193,68 +140,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
assert(tracker.waitFor(0) != 0)
}
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val workerCount: Int = numWorkers
val dummyTasks = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Communicator.shutdown()
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}
test("test Scala RabitTracker's workerConnectionTimeout") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(500)
val trackerEnvs = tracker.getWorkerEnvs
val dummyTasks = rdd.mapPartitions { iter =>
val index = iter.next()
// simulate that the first worker cannot connect to tracker due to network issues.
if (index != 1) {
Communicator.init(trackerEnvs)
Thread.sleep(1000)
Communicator.shutdown()
}
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
// should fail due to connection timeout
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}
test("should allow the dataframe containing communicator calls to be partially evaluated for" +
" multiple times (ISSUE-4406)") {
val paramMap = Map(

View File

@@ -17,13 +17,13 @@
package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.linalg.Vectors
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import org.apache.spark.sql.functions._
class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("perform deterministic partitioning when checkpointInternal and" +
" checkpointPath is set (Classifier)") {

View File

@@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.hadoop.fs.{FileSystem, Path}
class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
Map[String, Any] = {

View File

@@ -18,12 +18,12 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.sql.functions._
import scala.util.Random
class FeatureSizeValidatingSuite extends FunSuite with PerTest {
class FeatureSizeValidatingSuite extends AnyFunSuite with PerTest {
test("transform throwing exception if feature size of dataset is greater than model's") {
val modelPath = getClass.getResource("/model/0.82/model").getPath

View File

@@ -19,12 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import scala.util.Random
import org.apache.spark.SparkException
class MissingValueHandlingSuite extends FunSuite with PerTest {
class MissingValueHandlingSuite extends AnyFunSuite with PerTest {
test("dense vectors containing missing value") {
def buildDenseDataFrame(): DataFrame = {
val numRows = 100

View File

@@ -16,12 +16,13 @@
package ml.dmlc.xgboost4j.scala.spark
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.SparkException
import org.apache.spark.ml.param.ParamMap
class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
test("XGBoost and Spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",

View File

@@ -22,13 +22,14 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import scala.math.min
import scala.util.Random
import org.apache.commons.io.IOUtils
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)

View File

@@ -25,9 +25,9 @@ import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.functions._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class PersistenceSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
val eval = new EvalError()

View File

@@ -19,9 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.{Files, Path}
import org.apache.spark.network.util.JavaUtils
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
trait TmpFolderPerSuite extends BeforeAndAfterAll { self: FunSuite =>
trait TmpFolderPerSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
protected var tempDir: Path = _
override def beforeAll(): Unit = {

View File

@@ -22,13 +22,13 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg._
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.commons.io.IOUtils
import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuite {
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"

View File

@@ -21,11 +21,11 @@ import ml.dmlc.xgboost4j.scala.Booster
import scala.collection.JavaConverters._
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.SparkException
class XGBoostCommunicatorRegressionSuite extends FunSuite with PerTest {
class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
val predictionErrorMin = 0.00001f
val maxFailure = 2;

View File

@@ -19,9 +19,9 @@ package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
class XGBoostConfigureSuite extends FunSuite with PerTest {
class XGBoostConfigureSuite extends AnyFunSuite with PerTest {
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

View File

@@ -22,12 +22,12 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.{SparkException, TaskContext}
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.lit
class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("distributed training with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)

View File

@@ -23,11 +23,11 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row}
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.ml.feature.VectorAssembler
class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite {
class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {

View File

@@ -69,7 +69,7 @@ pom_template = """
<dependency>
<groupId>org.scalactic</groupId>
<artifactId>scalactic_${{scala.binary.version}}</artifactId>
<version>3.0.8</version>
<version>3.2.15</version>
<scope>test</scope>
</dependency>
<dependency>

View File

@@ -31,22 +31,10 @@
<version>4.13.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor_${scala.binary.version}</artifactId>
<version>2.6.20</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-testkit_${scala.binary.version}</artifactId>
<version>2.6.20</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>3.0.5</version>
<version>3.2.15</version>
<scope>provided</scope>
</dependency>
</dependencies>

View File

@@ -1,195 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit
import java.net.{InetAddress, InetSocketAddress}
import akka.actor.ActorSystem
import akka.pattern.ask
import ml.dmlc.xgboost4j.java.{IRabitTracker, TrackerProperties}
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler
import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.util.{Failure, Success, Try}
/**
 * Scala implementation of the Rabit tracker interface without a Python dependency.
 * The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker
 * (and thus any distributed tasks) from hanging indefinitely due to network issues or worker node
 * failures.
 *
 * Note that this implementation is currently experimental, and should be used at your own risk.
 *
 * Example usage:
 * {{{
 *   import scala.concurrent.duration._
 *
 *   val tracker = new RabitTracker(32)
 *   // allow up to 10 minutes for all workers to connect to the tracker.
 *   tracker.start(10 minutes)
 *
 *   /* ...
 *      launching workers in parallel
 *      ...
 *   */
 *
 *   // wait for worker execution up to 6 hours.
 *   // providing a finite timeout prevents a long-running task from hanging forever in
 *   // catastrophic events, like the loss of an executor during model training.
 *   tracker.waitFor(6 hours)
 * }}}
 *
 * @param numWorkers Number of distributed workers from which the tracker expects connections.
 * @param port The minimum port number that the tracker binds to.
 *       If port is omitted, or given as None, a random ephemeral port is chosen at runtime.
 * @param maxPortTrials The maximum number of trials of socket binding, by sequentially
 *       increasing the port number.
 */
private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
    maxPortTrials: Int = 1000)
  extends IRabitTracker {

  import scala.collection.JavaConverters._

  require(numWorkers >=1, "numWorkers must be greater than or equal to one (1).")

  // Dedicated actor system hosting the single tracker-handler actor below.
  val system = ActorSystem.create("RabitTracker")
  // Actor owning the TCP binding; all start/wait requests are brokered through it.
  val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
  // Timeout applied to every ask (?) sent to the handler actor.
  implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds)
  // How long start() waits for the actor to finish binding to a socket.
  private[this] val tcpBindingTimeout: Duration = 1 minute

  // Connection info for workers (tracker URI/port); populated once the socket is bound.
  var workerEnvs: Map[String, String] = Map.empty

  /**
   * Called when the thread driving distributed training dies with an uncaught exception.
   * Forwards the failure to the handler actor so that waitFor() fails fast instead of
   * blocking until timeout.
   */
  override def uncaughtException(t: Thread, e: Throwable): Unit = {
    handler ? RabitTrackerHandler.InterruptTracker(e)
  }

  /**
   * Start the Rabit tracker.
   *
   * @param timeout The timeout for awaiting connections from worker nodes.
   *   Note that when used in Spark applications, because all Spark transformations are
   *   lazily executed, the I/O time for loading RDDs/DataFrames from external sources
   *   (local dist, HDFS, S3 etc.) must be taken into account for the timeout value.
   *   If the timeout value is too small, the Rabit tracker will likely timeout before workers
   *   establishing connections to the tracker, due to the overhead of loading data.
   *   Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
   *   running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
   * @return Boolean flag indicating if the Rabit tracker starts successfully.
   */
  private def start(timeout: Duration): Boolean = {
    // Prefer an explicitly configured host IP; fall back to the local host address.
    val hostAddress = Option(TrackerProperties.getInstance().getHostIp)
      .map(InetAddress.getByName).getOrElse(InetAddress.getLocalHost)

    // Fire-and-forget ask: instructs the actor to bind; the reply is not awaited here.
    handler ? RabitTrackerHandler.StartTracker(
      new InetSocketAddress(hostAddress, port.getOrElse(0)), maxPortTrials, timeout)

    // Block by waiting for the actor to hand back the future that completes upon binding.
    Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
      .asInstanceOf[Future[Map[String, String]]]) match {
      case Success(futurePortBound) =>
        // The success of the Future is contingent on binding to an InetSocketAddress.
        val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess
        if (isBound) {
          // Future is already complete; zero-duration Await just extracts the value.
          workerEnvs = Await.result(futurePortBound, 0 nano)
        }
        isBound
      case Failure(ex: Throwable) =>
        // Ask timed out or the handler failed; report unsuccessful start.
        false
    }
  }

  /**
   * Start the Rabit tracker.
   *
   * @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
   *                                connections. If a non-positive value is provided, the tracker
   *                                waits for incoming worker connections indefinitely.
   * @return Boolean flag indicating if the Rabit tracker starts successfully.
   */
  def start(connectionTimeoutMillis: Long): Boolean = {
    if (connectionTimeoutMillis <= 0) {
      start(Duration.Inf)
    } else {
      start(Duration.fromNanos(connectionTimeoutMillis * 1e6))
    }
  }

  /** Shut down the tracker by terminating the underlying actor system. */
  def stop(): Unit = {
    system.terminate()
  }

  /**
   * Get a Map of necessary environment variables to initiate Rabit workers.
   *
   * @return HashMap containing tracker information.
   */
  def getWorkerEnvs: java.util.Map[String, String] = {
    new java.util.HashMap((workerEnvs ++ Map(
      "DMLC_NUM_WORKER" -> numWorkers.toString,
      "DMLC_NUM_SERVER" -> "0"
    )).asJava)
  }

  /**
   * Await workers to complete assigned tasks for at most 'atMost'.
   * This method blocks until timeout or task completion, and terminates the actor system
   * before returning.
   *
   * @param atMost the maximum execution time for the workers. By default,
   *               the tracker waits for the workers indefinitely.
   * @return 0 if the tasks complete successfully, and non-zero otherwise.
   */
  private def waitFor(atMost: Duration): Int = {
    // Request the completion Future from the tracker actor.
    Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
      .asInstanceOf[Future[Int]]) match {
      case Success(futureCompleted) =>
        // Wait for all workers to complete synchronously.
        val statusCode = Try(Await.result(futureCompleted, atMost)) match {
          case Success(n) if n == numWorkers =>
            IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
          case Success(n) if n < numWorkers =>
            // Fewer shutdown signals than workers: treated as timeout.
            IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
          case Failure(e) =>
            IRabitTracker.TrackerStatus.FAILURE.getStatusCode
        }
        system.terminate()
        statusCode
      case Failure(ex: Throwable) =>
        // Could not even obtain the completion future from the actor.
        system.terminate()
        IRabitTracker.TrackerStatus.FAILURE.getStatusCode
    }
  }

  /**
   * Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
   * This method blocks until timeout or task completion.
   *
   * @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a
   *                     non-positive number is given, the tracker waits indefinitely.
   * @return 0 if the tasks complete successfully, and non-zero otherwise
   */
  def waitFor(atMostMillis: Long): Int = {
    if (atMostMillis <= 0) {
      waitFor(Duration.Inf)
    } else {
      waitFor(Duration.fromNanos(atMostMillis * 1e6))
    }
  }
}

View File

@@ -1,361 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.handler
import java.net.InetSocketAddress
import java.util.UUID
import scala.concurrent.duration._
import scala.collection.mutable
import scala.concurrent.{Promise, TimeoutException}
import akka.io.{IO, Tcp}
import akka.actor._
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}
import scala.util.{Failure, Random, Success, Try}
/** The Akka actor for handling and coordinating Rabit worker connections.
  * This is the main actor for handling socket connections, interacting with the synchronous
  * tracker interface, and resolving tree/ring/parent dependencies between workers.
  *
  * @param numWorkers Number of workers to track.
  */
private[scala] class RabitTrackerHandler(numWorkers: Int)
  extends Actor with ActorLogging {

  import context.system
  import RabitWorkerHandler._
  import RabitTrackerHandler._

  // Completed with tracker URI/port once the TCP socket is bound; handed to RabitTracker.start().
  private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
  // Completed with the number of workers that sent "shutdown"; handed to RabitTracker.waitFor().
  private[this] val promisedShutdownWorkers = Promise[Int]()
  private[this] val tcpManager = IO(Tcp)

  // resolves worker connection dependency.
  val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")

  // workers that have sent "shutdown" signal
  private[this] val shutdownWorkers = mutable.Set.empty[Int]
  // jobId -> assigned rank, so a recovering worker with a known jobId keeps its rank.
  private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
  // child handler actor -> remote host name, for log messages only.
  private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
  // pool of unassigned ranks; ranks are handed out in connection order.
  private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
  private[this] var maxPortTrials = 0
  private[this] var workerConnectionTimeout: Duration = Duration.Inf
  private[this] var portTrials = 0
  private[this] val startedWorkers = mutable.Set.empty[Int]

  val linkMap = new LinkMap(numWorkers)

  /**
   * Determine the rank for a worker: a non-negative purported rank wins; otherwise look up
   * a previously assigned rank by jobId. Returns None for a brand-new worker (caller then
   * draws from ranksToAssign).
   */
  def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
    rank match {
      case r if r >= 0 => Some(r)
      case _ =>
        jobId match {
          case "NULL" => None
          case jid => jobToRankMap.get(jid)
        }
    }
  }

  /**
   * Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
   * by the RabitWorkerHandler.
   *
   * @param event Generic Tcp.Event
   */
  private def handleTcpEvents(event: Tcp.Event): Unit = event match {
    case Tcp.Bound(local) =>
      // expect all workers to connect within timeout
      log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
      log.info(s"Worker connection timeout is $workerConnectionTimeout.")
      context.setReceiveTimeout(workerConnectionTimeout)
      promisedWorkerEnvs.success(Map(
        "DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
        "DMLC_TRACKER_PORT" -> local.getPort.toString,
        // not required because the world size will be communicated to the
        // worker node after the rank is assigned.
        "rabit_world_size" -> numWorkers.toString
      ))

    case Tcp.CommandFailed(cmd: Tcp.Bind) =>
      // Binding failed: retry on the next port number, up to maxPortTrials attempts.
      // NOTE(review): when portTrials reaches maxPortTrials, promisedWorkerEnvs is never
      // completed, so start() only fails after tcpBindingTimeout elapses — confirm intended.
      if (portTrials < maxPortTrials) {
        portTrials += 1
        tcpManager ! Tcp.Bind(self,
          new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1),
          backlog = 256)
      }

    case Tcp.Connected(remote, local) =>
      log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
      // Spawn a dedicated per-connection handler and register it for socket I/O.
      val workerHandler = context.actorOf(RabitWorkerHandler.props(
        remote.getAddress.getHostAddress, numWorkers, self, sender()
      ), s"ConnectionHandler-${UUID.randomUUID().toString}")
      val connection = sender()
      // keepOpenOnPeerClosed lets the handler finish protocol steps after the worker closes.
      connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)
      actorRefToHost.put(workerHandler, remote.getAddress.getHostName)
  }

  /**
   * Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
   * to interact with the tracker interface.
   *
   * @param trackerMsg control messages sent by RabitTracker class.
   */
  private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
    trackerMsg match {
      case msg: StartTracker =>
        maxPortTrials = msg.maxPortTrials
        workerConnectionTimeout = msg.connectionTimeout
        // if the port number is missing, try binding to a random ephemeral port.
        if (msg.addr.getPort == 0) {
          tcpManager ! Tcp.Bind(self,
            new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
            backlog = 256)
        } else {
          tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
        }
        sender() ! true

      case RequestBoundFuture =>
        // Reply with the future that completes once the socket is bound.
        sender() ! promisedWorkerEnvs.future

      case RequestCompletionFuture =>
        // Reply with the future that completes once all workers have shut down.
        sender() ! promisedShutdownWorkers.future

      case InterruptTracker(e) =>
        log.error(e, "Uncaught exception thrown by worker.")
        // make sure that waitFor() does not hang indefinitely.
        promisedShutdownWorkers.failure(e)
        context.stop(self)
    }

  /**
   * Handles messages sent by child actors representing connecting Rabit workers, by brokering
   * messages to the dependency resolver, and processing worker commands.
   *
   * @param workerMsg Message sent by RabitWorkerHandler actors.
   */
  private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
    case req @ RequestAwaitConnWorkers(_, _) =>
      // since the requester may request to connect to other workers
      // that have not fully set up, delegate this request to the
      // dependency resolver which handles the dependencies properly.
      resolver forward req

    // ---- Rabit worker commands: start/recover/shutdown/print ----
    case WorkerTrackerPrint(_, _, _, msg) =>
      log.info(msg.trim)

    case WorkerShutdown(rank, _, _) =>
      assert(rank >= 0, "Invalid rank.")
      assert(!shutdownWorkers.contains(rank))
      shutdownWorkers.add(rank)
      log.info(s"Received shutdown signal from $rank")
      // All workers done: fulfill the completion promise so waitFor() returns.
      if (shutdownWorkers.size == numWorkers) {
        promisedShutdownWorkers.success(shutdownWorkers.size)
      }

    case WorkerRecover(prevRank, worldSize, jobId) =>
      assert(prevRank >= 0)
      // Recovering worker keeps its previous rank and link assignments.
      sender() ! linkMap.assignRank(prevRank)

    case WorkerStart(rank, worldSize, jobId) =>
      assert(worldSize == numWorkers || worldSize == -1,
        s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
      )
      // Use the purported/remembered rank, or draw the next free rank; removing from an
      // empty buffer throws IndexOutOfBoundsException (handled below as over-subscription).
      Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
        case Success(wkRank) =>
          if (jobId != "NULL") {
            jobToRankMap.put(jobId, wkRank)
          }
          val assignedRank = linkMap.assignRank(wkRank)
          sender() ! assignedRank
          resolver ! assignedRank
          log.info("Received start signal from " +
            s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]")

        case Failure(ex: IndexOutOfBoundsException) =>
          // More than worldSize workers have connected, likely due to executor loss.
          // Since Rabit currently does not support crash recovery (because the Allreduce results
          // are not cached by the tracker, and because existing workers cannot reestablish
          // connections to newly spawned executor/worker), the most reasonable action here is to
          // interrupt the tracker immediately with failure state.
          log.error("Received invalid start signal from " +
            s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started."
          )
          promisedShutdownWorkers.failure(new XGBoostError("Invalid start signal" +
            " received from worker, likely due to executor loss."))

        case Failure(ex) =>
          log.error(ex, "Unexpected error")
          promisedShutdownWorkers.failure(ex)
      }

    // ---- Dependency resolving related messages ----
    case msg @ WorkerStarted(host, rank, awaitingAcceptance) =>
      log.info(s"Worker $host (rank: $rank) has started.")
      resolver forward msg
      startedWorkers.add(rank)
      if (startedWorkers.size == numWorkers) {
        log.info("All workers have started.")
      }

    case req @ DropFromWaitingList(_) =>
      // all peer workers in dependency link map have connected;
      // forward message to resolver to update dependencies.
      resolver forward req

    case _ =>
  }

  // Top-level dispatch: route by message category, plus the connection-timeout watchdog.
  def receive: Actor.Receive = {
    case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
    case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
    case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)
    case akka.actor.ReceiveTimeout =>
      // Fired when no message arrives within workerConnectionTimeout; if workers are still
      // missing, fail the completion promise so the driver does not hang.
      if (startedWorkers.size < numWorkers) {
        promisedShutdownWorkers.failure(
          new TimeoutException("Timed out waiting for workers to connect: " +
            s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
        )
        context.stop(self)
      }
      // Disable further receive-timeout notifications.
      context.setReceiveTimeout(Duration.Undefined)
  }
}
/**
 * Resolve the dependency between nodes as they connect to the tracker.
 * The dependency is enforced such that a worker of rank K depends on its neighbors (from the
 * treeMap and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order of
 * connections by workers, this dependency constraint assumes that a worker node that connects
 * first is likely to finish setup first.
 */
private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
  import RabitWorkerHandler._

  // Stop this resolver if the parent tracker handler terminates.
  context.watch(handler)

  // A deferred reply: the set of ranks the requester must connect to, and the promise to
  // fulfill once all of the requester's dependencies have started.
  case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])

  // worker nodes that have connected, but have not sent the WorkerStarted message:
  // rank -> set of smaller ranks it depends on.
  private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
  private val startedWorkers = mutable.Set.empty[Int]
  // worker nodes that have started, and await connections: rank -> handler actor.
  private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
  // requests deferred until dependencies are satisfied, keyed by requester rank.
  private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]

  /**
   * Build the reply for a connection request: the subset of linked workers currently accepting
   * connections, plus the count of linked workers not yet reachable (they will connect to the
   * requester later instead).
   */
  def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
    val connSet = awaitConnWorkers.toMap
      .filterKeys(k => linkSet.contains(k))
    AwaitingConnections(connSet, linkSet.size - connSet.size)
  }

  def receive: Actor.Receive = {
    // a copy of the AssignedRank message that is also sent to the worker
    case AssignedRank(rank, tree_neighbors, ring, parent) =>
      // the workers that the worker of given `rank` depends on:
      // worker of rank K only depends on workers with rank smaller than K.
      val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
        .filter{ r => r != -1 && r < rank}
      log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
      dependencyMap.put(rank, dependentWorkers)

    case RequestAwaitConnWorkers(rank, toConnectSet) =>
      val promise = Promise[AwaitingConnections]()
      assert(dependencyMap.contains(rank))
      val updatedDependency = dependencyMap(rank) diff startedWorkers
      if (updatedDependency.isEmpty) {
        // all dependencies are satisfied
        log.debug(s"Rank $rank has all dependencies satisfied.")
        promise.success(awaitingWorkers(toConnectSet))
      } else {
        log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
        // promise is pending fulfillment due to unresolved dependency
        pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
      }
      // Reply with the future either already completed or pending fulfillment.
      sender() ! promise.future

    case WorkerStarted(_, started, awaitingAcceptance) =>
      startedWorkers.add(started)
      if (awaitingAcceptance > 0) {
        awaitConnWorkers.put(started, sender())
      }
      // remove the started rank from all dependencies.
      dependencyMap.remove(started)
      dependencyMap.foreach { case (r, dset) =>
        val updatedDependency = dset diff startedWorkers
        // fulfill the future if all dependencies are met (started.)
        if (updatedDependency.isEmpty) {
          log.debug(s"Rank $r has all dependencies satisfied.")
          pendingFulfillment.remove(r).map{
            case Fulfillment(toConnectSet, promise) =>
              promise.success(awaitingWorkers(toConnectSet))
          }
        }
        dependencyMap.update(r, updatedDependency)
      }

    case DropFromWaitingList(rank) =>
      // The worker no longer awaits peer connections; it must have been registered before.
      assert(awaitConnWorkers.remove(rank).isDefined)

    case Terminated(ref) =>
      if (ref.equals(handler)) {
        context.stop(self)
      }
  }
}
/** Companion object defining the control-message protocol between RabitTracker (the synchronous
  * interface) and the RabitTrackerHandler actor, plus the actor's Props factory.
  */
private[scala] object RabitTrackerHandler {
  // Messages sent by RabitTracker to this RabitTrackerHandler actor
  trait TrackerControlMessage

  // Ask for the Future that completes when all workers have shut down (used by waitFor()).
  case object RequestCompletionFuture extends TrackerControlMessage
  // Ask for the Future that completes when the tracker socket is bound (used by start()).
  case object RequestBoundFuture extends TrackerControlMessage

  // Start the Rabit tracker at given socket address awaiting worker connections.
  // All workers must connect to the tracker before connectionTimeout, otherwise the tracker will
  // shut down due to timeout.
  case class StartTracker(addr: InetSocketAddress,
                          maxPortTrials: Int,
                          connectionTimeout: Duration) extends TrackerControlMessage

  // To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
  // driver for the distributed training.
  case class InterruptTracker(e: Throwable) extends TrackerControlMessage

  /** Props factory for creating a RabitTrackerHandler tracking `numWorkers` workers. */
  def props(numWorkers: Int): Props =
    Props(new RabitTrackerHandler(numWorkers))
}

View File

@@ -1,467 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.handler
import java.nio.{ByteBuffer, ByteOrder}
import akka.io.Tcp
import akka.actor._
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Try
/**
 * Actor to handle socket communication from a worker node.
 * To handle fragmentation in received data, this class acts like an FSM
 * (finite-state machine) to keep track of the internal states.
 *
 * States progress: AwaitingHandshake -> AwaitingCommand -> BuildingLinkMap ->
 * AwaitingErrorCount -> AwaitingPortNumber -> SetupComplete (see companion object).
 *
 * @param host IP address of the remote worker
 * @param worldSize number of total workers
 * @param tracker the RabitTrackerHandler actor reference
 * @param connection the Akka Tcp connection actor for this worker's socket
 */
private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
    connection: ActorRef)
  extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
  with ActorLogging with Stash {

  import RabitWorkerHandler._
  import RabitTrackerHelpers._

  // Rank assigned to this worker by the tracker (set on AssignedRank).
  private[this] var rank: Int = 0
  // Port the worker listens on for peer connections (reported in AwaitingPortNumber).
  private[this] var port: Int = 0

  // indicate if the connection is transient (like "print" or "shutdown")
  private[this] var transient: Boolean = false
  private[this] var peerClosed: Boolean = false
  // number of workers pending acceptance of current worker
  private[this] var awaitingAcceptance: Int = 0
  // ranks of all workers linked to this one via the tree/ring maps.
  private[this] var neighboringWorkers = Set.empty[Int]

  // TODO: use a single memory allocation to host all buffers,
  // including the transient ones for writing.
  private[this] val readBuffer = ByteBuffer.allocate(4096)
    .order(ByteOrder.nativeOrder())

  // in case the received message is longer than needed,
  // stash the spilled over part in this buffer, and send
  // to self when transition occurs.
  private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
    .order(ByteOrder.nativeOrder())

  // when setup is complete, need to notify peer handlers
  // to reduce the awaiting-connection counter.
  private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None

  /** Clear the read buffer and re-deliver any spilled-over bytes as a fresh Tcp.Received. */
  private def resetBuffers(): Unit = {
    readBuffer.clear()
    if (spillOverBuffer.position() > 0) {
      spillOverBuffer.flip()
      self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
      spillOverBuffer.clear()
    }
  }

  /** Save any unread remainder of `buf` for re-delivery after the next state transition. */
  private def stashSpillOver(buf: ByteBuffer): Unit = {
    if (buf.remaining() > 0) spillOverBuffer.put(buf)
  }

  def getNeighboringWorkers: Set[Int] = neighboringWorkers

  /**
   * Decode a complete tracker command (rank, worldSize, jobId, command[, msg]) from the
   * accumulated read buffer into a TrackerCommand message. Works on a duplicate so the
   * caller's buffer position is untouched; leftover bytes are stashed as spill-over.
   *
   * NOTE(review): the match on `command` has no default case — an unrecognized command
   * string would throw a MatchError here; presumably workers only ever send the four
   * known commands — confirm against the Rabit worker protocol.
   */
  def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
    val readBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
    readBuffer.flip()

    val rank = readBuffer.getInt()
    val worldSize = readBuffer.getInt()
    val jobId = readBuffer.getString
    val command = readBuffer.getString

    val trackerCommand = command match {
      case "start" => WorkerStart(rank, worldSize, jobId)
      case "shutdown" =>
        // shutdown/print connections are short-lived: stop this actor once the peer closes.
        transient = true
        WorkerShutdown(rank, worldSize, jobId)
      case "recover" =>
        require(rank >= 0, "Invalid rank for recovering worker.")
        WorkerRecover(rank, worldSize, jobId)
      case "print" =>
        transient = true
        WorkerTrackerPrint(rank, worldSize, jobId, readBuffer.getString)
    }

    stashSpillOver(readBuffer)
    trackerCommand
  }

  startWith(AwaitingHandshake, DataStruct())

  // [1] Worker must open with the 4-byte magic number, which is echoed back.
  when(AwaitingHandshake) {
    case Event(Tcp.Received(magic), _) =>
      assert(magic.length == 4)
      val purportedMagic = magic.asNativeOrderByteBuffer.getInt
      assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")

      // echo back the magic number
      connection ! Tcp.Write(magic)
      goto(AwaitingCommand) using StructTrackerCommand
  }

  // [2] Accumulate bytes until a full command struct is present, then decode and forward it.
  when(AwaitingCommand) {
    case Event(Tcp.Received(bytes), validator) =>
      bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
      if (validator.verify(readBuffer)) {
        Try(decodeCommand(readBuffer)) match {
          case scala.util.Success(decodedCommand) =>
            tracker ! decodedCommand
          case scala.util.Failure(th: java.nio.BufferUnderflowException) =>
            // BufferUnderflowException would occur if the message to print has not arrived yet.
            // Do nothing, wait for next Tcp.Received event
          case scala.util.Failure(th: Throwable) => throw th
        }
      }
      stay

    // when rank for a worker is assigned, send encoded rank information
    // back to worker over Tcp socket.
    case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
      log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")
      rank = assignedRank
      // ranks from the ring
      val ringRanks = List(
        // ringPrev
        if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
        // ringNext
        if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
      )
      // update the set of all linked workers to current worker.
      neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet

      connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
      // to prevent reading before state transition
      connection ! Tcp.SuspendReading
      goto(BuildingLinkMap) using StructNodes
  }

  // [3] Learn which linked peers the worker is already connected to, then broker connections
  // to the remaining ones via the tracker's dependency resolver.
  when(BuildingLinkMap) {
    case Event(Tcp.Received(bytes), validator) =>
      bytes.asByteBuffers.foreach { buf =>
        readBuffer.put(buf)
      }
      if (validator.verify(readBuffer)) {
        readBuffer.flip()
        // for a freshly started worker, numConnected should be 0.
        val numConnected = readBuffer.getInt()
        val toConnectSet = neighboringWorkers.diff(
          (0 until numConnected).map { index => readBuffer.getInt() }.toSet)
        // check which workers are currently awaiting connections
        tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
      }
      stay

    // got a Future from the tracker (resolver) about workers that are
    // currently awaiting connections (particularly from this node.)
    case Event(future: Future[_], _) =>
      // blocks execution until all dependencies for current worker is resolved.
      Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
        // numNotReachable is the number of workers that currently
        // cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
        // connections from those currently non-reachable nodes in the future.
        case AwaitingConnections(waitConnNodes, numNotReachable) =>
          log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")

          val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
          buf.putInt(waitConnNodes.size).putInt(numNotReachable)
          buf.flip()

          // cache this message until the final state (SetupComplete)
          pendingAcknowledgement = Some(AcknowledgeAcceptance(
            waitConnNodes, numNotReachable))

          connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
          if (waitConnNodes.isEmpty) {
            connection ! Tcp.SuspendReading
            goto(AwaitingErrorCount)
          }
          else {
            // ask each reachable peer handler for its host/port to relay to this worker.
            waitConnNodes.foreach { case (peerRank, peerRef) =>
              peerRef ! RequestWorkerHostPort
            }

            // a countdown for DivulgedHostPort messages.
            stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
          }
      }

    case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
      // Relay one peer's (host, port, rank) triple to the worker over the socket.
      val hostBytes = peerHost.getBytes()
      val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
        .order(ByteOrder.nativeOrder())

      buffer.putInt(peerHost.length).put(hostBytes)
        .putInt(peerPort).putInt(peerRank)
      buffer.flip()

      connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))
      if (data.counter == 0) {
        // to prevent reading before state transition
        connection ! Tcp.SuspendReading
        goto(AwaitingErrorCount)
      }
      else {
        stay using data.decrement()
      }
  }

  // [4] Worker reports how many peer connections failed; zero moves setup forward, otherwise
  // rebuild the link map.
  when(AwaitingErrorCount) {
    case Event(Tcp.Received(numErrors), _) =>
      val buf = numErrors.asNativeOrderByteBuffer
      buf.getInt match {
        case 0 =>
          stashSpillOver(buf)
          goto(AwaitingPortNumber)
        case _ =>
          stashSpillOver(buf)
          goto(BuildingLinkMap) using StructNodes
      }
  }

  // [5] Worker reports its listening port; setup completes once both the port is known and
  // the worker has closed its side of the connection.
  when(AwaitingPortNumber) {
    case Event(Tcp.Received(assignedPort), _) =>
      assert(assignedPort.length == 4)
      port = assignedPort.asNativeOrderByteBuffer.getInt
      log.debug(s"Rank $rank listening @ $host:$port")
      // wait until the worker closes connection.
      if (peerClosed) goto(SetupComplete) else stay

    case Event(Tcp.PeerClosed, _) =>
      peerClosed = true
      if (port == 0) stay else goto(SetupComplete)
  }

  // [6] Setup done; remain alive to answer peer handlers until no worker awaits this one.
  when(SetupComplete) {
    case Event(ReduceWaitCount(count: Int), _) =>
      awaitingAcceptance -= count
      // check peerClosed to avoid prematurely stopping this actor (which sends RST to worker)
      if (awaitingAcceptance == 0 && peerClosed) {
        tracker ! DropFromWaitingList(rank)
        // no longer needed.
        context.stop(self)
      }
      stay

    case Event(AcknowledgeAcceptance(peers, numBad), _) =>
      awaitingAcceptance = numBad
      tracker ! WorkerStarted(host, rank, awaitingAcceptance)
      // tell each already-connected peer that this worker has accepted its connection.
      peers.values.foreach { peer =>
        peer ! ReduceWaitCount(1)
      }

      if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill
      stay

    // can only divulge the complete host and port information
    // when this worker is declared fully connected (otherwise
    // port information is still missing.)
    case Event(RequestWorkerHostPort, _) =>
      sender() ! DivulgedWorkerHostPort(rank, host, port)
      stay
  }

  onTransition {
    // reset buffer when state transitions as data becomes stale
    case _ -> SetupComplete =>
      connection ! Tcp.ResumeReading
      resetBuffers()
      // deliver the deferred acknowledgement now that the final state is reached.
      if (pendingAcknowledgement.isDefined) {
        self ! pendingAcknowledgement.get
      }
    case _ =>
      connection ! Tcp.ResumeReading
      resetBuffers()
  }

  // default message handler
  whenUnhandled {
    case Event(Tcp.PeerClosed, _) =>
      peerClosed = true
      // transient connections (print/shutdown) need no further protocol steps.
      if (transient) context.stop(self)
      stay
  }
}
private[scala] object RabitWorkerHandler {
val MAGIC_NUMBER = 0xff99
// Finite states of this actor, which acts like a FSM.
// The following states are defined in order as the FSM progresses.
sealed trait State
// [1] Initial state, awaiting worker to send magic number per protocol.
case object AwaitingHandshake extends State
// [2] Awaiting worker to send command (start/print/recover/shutdown etc.)
case object AwaitingCommand extends State
// [3] Brokers connections between workers per ring/tree/parent link map.
case object BuildingLinkMap extends State
// [4] A transient state in which the worker reports the number of errors in establishing
// connections to other peer workers. If no errors, transition to next state.
case object AwaitingErrorCount extends State
// [5] Awaiting the worker to report its port number for accepting connections from peer workers.
// This port number information is later forwarded to linked workers.
case object AwaitingPortNumber extends State
// [6] Final state after completing the setup with the connecting worker. At this stage, the
// worker will have closed the Tcp connection. The actor remains alive to handle messages from
// peer actors representing workers with pending setups.
case object SetupComplete extends State
sealed trait DataField
case object IntField extends DataField
// an integer preceding the actual string
case object StringField extends DataField
case object IntSeqField extends DataField
object DataStruct {
def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
}
// Internal data pertaining to individual state, used to verify the validity of packets sent by
// workers.
case class DataStruct(fields: Seq[DataField], counter: Int) {
/**
* Validate whether the provided buffer is complete (i.e., contains
* all data fields specified for this DataStruct.)
*
* @param buf a byte buffer containing received data.
*/
def verify(buf: ByteBuffer): Boolean = {
if (fields.isEmpty) return true
val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
dupBuf.flip()
Try(fields.foldLeft(true) {
case (complete, field) =>
val remBytes = dupBuf.remaining()
complete && (remBytes > 0) && (remBytes >= (field match {
case IntField =>
dupBuf.position(dupBuf.position() + 4)
4
case StringField =>
val strLen = dupBuf.getInt
dupBuf.position(dupBuf.position() + strLen)
4 + strLen
case IntSeqField =>
val seqLen = dupBuf.getInt
dupBuf.position(dupBuf.position() + seqLen * 4)
4 + seqLen * 4
}))
}).getOrElse(false)
}
def increment(): DataStruct = DataStruct(fields, counter + 1)
def decrement(): DataStruct = DataStruct(fields, counter - 1)
}
val StructNodes = DataStruct(List(IntSeqField), 0)
val StructTrackerCommand = DataStruct(List(
IntField, IntField, StringField, StringField
), 0)
// ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ----
// RabitWorkerHandler --> RabitTrackerHandler
sealed trait RabitWorkerRequest
// RabitWorkerHandler <-- RabitTrackerHandler
sealed trait RabitWorkerResponse
// Representations of decoded worker commands.
abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
def rank: Int
def worldSize: Int
def jobId: String
def encode: ByteString = {
val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length)
.order(ByteOrder.nativeOrder())
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
.putInt(command.length).put(command.getBytes()).flip()
ByteString.fromByteBuffer(buf)
}
}
case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("start")
case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("shutdown")
case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("recover")
case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
extends TrackerCommand("print") {
override def encode: ByteString = {
val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length)
.order(ByteOrder.nativeOrder())
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
.putInt(command.length).put(command.getBytes())
.putInt(msg.length).put(msg.getBytes()).flip()
ByteString.fromByteBuffer(buf)
}
}
// Request to remove the worker of given rank from the list of workers awaiting peer connections.
case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
// Notify the tracker that the worker of given rank has finished setup and started.
case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
extends RabitWorkerRequest
// Request the set of workers to connect to, according to the LinkMap structure.
case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
extends RabitWorkerRequest
// Request, from the tracker, the set of nodes to connect.
case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
extends RabitWorkerResponse
// ---- Messages between ConnectionHandler actors ----
sealed trait IntraWorkerMessage
// Notify neighboring workers to decrease the counter of awaiting workers by `count`.
case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
// Request host and port information from peer ConnectionHandler actors (acting on behalf of
// connecting workers.) This message will be brokered by RabitTrackerHandler.
case object RequestWorkerHostPort extends IntraWorkerMessage
// Response to the above request: the contact information of the worker of given rank.
case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
// A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
  extends IntraWorkerMessage
/**
 * Akka Props factory for a RabitWorkerHandler bound to a single worker TCP connection.
 *
 * @param host       host name/address reported for this worker.
 * @param worldSize  total number of distributed workers in the job.
 * @param tracker    the RabitTrackerHandler actor brokering worker coordination.
 * @param connection the Tcp connection actor for this worker's socket.
 */
def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
  Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
}
}

View File

@@ -1,136 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.util
import java.nio.{ByteBuffer, ByteOrder}
/**
* The assigned rank to a connecting Rabit worker, along with the information of the ranks of
* its linked peer workers, which are critical to perform Allreduce.
* When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
* client, RabitTrackerHandler utilizes LinkMap to figure out linkage relationships, and respond
* with this class as a message, which is later encoded to byte string, and sent over socket
* connection to the worker client.
*
* @param rank assigned rank (ranked by worker connection order: first worker connecting to the
* tracker is assigned rank 0, second with rank 1, etc.)
* @param neighbors ranks of neighboring workers in a tree map.
* @param ring ranks of neighboring workers in a ring map.
* @param parent rank of the parent worker.
*/
private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
                                       ring: (Int, Int), parent: Int) {
  /**
   * Serialize this rank assignment for the socket handshake with the worker client.
   *
   * Layout (all native-endian Ints): rank, parent, worldSize, neighbor count,
   * each tree neighbor, ring-previous, ring-next.
   *
   * @param worldSize the number of total distributed workers. Must match `numWorkers` used in
   *                  LinkMap.
   * @return a flipped ByteBuffer containing the encoded data, ready to be written.
   */
  def toByteBuffer(worldSize: Int): ByteBuffer = {
    // A ring link that is absent (-1) or points back at this worker carries no
    // usable peer, so it is normalized to the -1 sentinel before encoding.
    val ringPrev = if (ring._1 == -1 || ring._1 == rank) -1 else ring._1
    val ringNext = if (ring._2 == -1 || ring._2 == rank) -1 else ring._2
    // 6 fixed ints plus one int per tree neighbor, 4 bytes each.
    val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder())
    buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length)
    for (n <- neighbors) {
      buffer.putInt(n)
    }
    buffer.putInt(ringPrev).putInt(ringNext)
    buffer.flip()
    buffer
  }
}
/**
 * Computes the tree, ring and parent linkage between `numWorkers` Rabit workers.
 * The private maps use "connection-order" ranks; the public maps are re-expressed
 * in ring-walk order via `rMap_` so that successive ranks are ring neighbors.
 */
private[rabit] class LinkMap(numWorkers: Int) {
  // Tree neighbors of `rank` using 1-based heap indexing (parent = i / 2,
  // children = 2i and 2i + 1), converted back to 0-based and range-filtered.
  private def getNeighbors(rank: Int): Seq[Int] = {
    val rank1 = rank + 1
    Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
      r >= 0 && r < numWorkers
    }
  }
  /**
   * Construct a ring structure that tends to share nodes with the tree.
   *
   * @param treeMap   rank -> tree neighbors, as produced by `getNeighbors`.
   * @param parentMap rank -> parent rank (-1 for the root).
   * @param rank      rank at which the (sub-)ring starts.
   * @return Seq[Int] instance starting from rank.
   */
  private def constructShareRing(treeMap: Map[Int, Seq[Int]],
                                 parentMap: Map[Int, Int],
                                 rank: Int = 0): Seq[Int] = {
    // Children of `rank` = tree neighbors minus its parent.
    treeMap(rank).toSet - parentMap(rank) match {
      case emptySet if emptySet.isEmpty =>
        // Leaf node: the sub-ring is just this rank.
        List(rank)
      case connectionSet =>
        // Concatenate each child's sub-ring after this rank; the first child's
        // singleton sub-ring is reversed (a no-op for size 1, kept as written).
        connectionSet.zipWithIndex.foldLeft(List(rank)) {
          case (ringSeq, (v, cnt)) =>
            val vConnSeq = constructShareRing(treeMap, parentMap, v)
            vConnSeq match {
              case vconn if vconn.size == cnt + 1 =>
                ringSeq ++ vconn.reverse
              case vconn =>
                ringSeq ++ vconn
            }
        }
    }
  }
  /**
   * Construct a ring connection used to recover local data.
   *
   * @param treeMap   rank -> tree neighbors.
   * @param parentMap rank -> parent rank (-1 for the root, which must be rank 0).
   */
  private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = {
    assert(parentMap(0) == -1)
    val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
    assert(sharedRing.length == treeMap.size)
    // Map each rank to its (previous, next) ring neighbors, wrapping around.
    (0 until numWorkers).map { r =>
      val rPrev = (r + numWorkers - 1) % numWorkers
      val rNext = (r + 1) % numWorkers
      sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
    }.toMap
  }
  private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
  private[this] val parentMap_ = (0 until numWorkers).map{ r => r -> ((r + 1) / 2 - 1) }.toMap
  private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)
  // Relabeling map: walking the ring from rank 0 yields new ranks 0, 1, 2, ...
  val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
    case ((rmap, k), i) =>
      val kNext = ringMap_(k)._2
      (rmap ++ Map(kNext -> (i + 1)), kNext)
  }._1
  // Public linkage maps, re-expressed in the relabeled ranks.
  val ringMap = ringMap_.map {
    case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
  }
  val treeMap = treeMap_.map {
    case (k, vSeq) => rMap_(k) -> vSeq.map{ v => rMap_(v) }
  }
  val parentMap = parentMap_.map {
    case (k, v) if k == 0 =>
      // Root keeps the -1 "no parent" sentinel.
      rMap_(k) -> -1
    case (k, v) =>
      rMap_(k) -> rMap_(v)
  }
  // Bundle tree, ring and parent linkage for the worker of the given (relabeled) rank.
  def assignRank(rank: Int): AssignedRank = {
    AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
  }
}

View File

@@ -1,39 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.util
import java.nio.{ByteOrder, ByteBuffer}
import akka.util.ByteString
/** Implicit decoding helpers shared by the Rabit tracker actors. */
private[rabit] object RabitTrackerHelpers {
  implicit class ByteStringHelplers(bs: ByteString) {
    // Java by default uses big endian. Enforce native endian so that
    // the byte order is consistent with the workers.
    def asNativeOrderByteBuffer: ByteBuffer = {
      bs.asByteBuffer.order(ByteOrder.nativeOrder())
    }
  }
  implicit class ByteBufferHelpers(buf: ByteBuffer) {
    /** Read a length-prefixed string (Int length, then that many bytes, UTF-8). */
    def getString: String = {
      val len = buf.getInt()
      val raw = new Array[Byte](len)
      buf.get(raw, 0, len)
      new String(raw, "utf-8")
    }
  }
}

View File

@@ -30,8 +30,8 @@ import org.junit.Test;
* @author hzx
*/
public class BoosterImplTest {
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1";
private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1";
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1&format=libsvm";
public static class EvalError implements IEvaluation {
@Override

View File

@@ -4,7 +4,7 @@
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
@@ -88,7 +88,7 @@ public class DMatrixTest {
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file
String filePath = writeResourceIntoTempFile("/agaricus.txt.test");
DMatrix dmat = new DMatrix(filePath);
DMatrix dmat = new DMatrix(filePath + "?format=libsvm");
//get label
float[] labels = dmat.getLabel();
//check length

View File

@@ -20,12 +20,12 @@ import java.util.Arrays
import scala.util.Random
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
class DMatrixSuite extends FunSuite {
class DMatrixSuite extends AnyFunSuite {
test("create DMatrix from File") {
val dmat = new DMatrix("../../demo/data/agaricus.txt.test")
val dmat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
// get label
val labels: Array[Float] = dmat.getLabel
// check length

View File

@@ -20,11 +20,11 @@ import java.io.{FileOutputStream, FileInputStream, File}
import junit.framework.TestCase
import org.apache.commons.logging.LogFactory
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.java.XGBoostError
class ScalaBoosterImplSuite extends FunSuite {
class ScalaBoosterImplSuite extends AnyFunSuite {
private class EvalError extends EvalTrait {
@@ -95,8 +95,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("basic operation of booster") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val booster = trainBooster(trainMat, testMat)
val predicts = booster.predict(testMat, true)
@@ -106,8 +106,8 @@ class ScalaBoosterImplSuite extends FunSuite {
test("save/load model with path") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val eval = new EvalError
val booster = trainBooster(trainMat, testMat)
// save and load
@@ -123,8 +123,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("save/load model with stream") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val eval = new EvalError
val booster = trainBooster(trainMat, testMat)
// save and load
@@ -139,7 +139,7 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("cross validation") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val params = List("eta" -> "1.0", "max_depth" -> "3", "silent" -> "1", "nthread" -> "6",
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2
@@ -148,8 +148,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test with quantile histo depthwise") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "3", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "eval_metric" -> "auc").toMap
@@ -158,8 +158,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test with quantile histo lossguide") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "3", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "auc").toMap
@@ -168,8 +168,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test with quantile histo lossguide with max bin") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "3", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
@@ -179,8 +179,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test with quantile histo depthwidth with max depth") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
@@ -190,8 +190,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test with quantile histo depthwidth with max depth and max bin") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
@@ -201,7 +201,7 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test training from existing model in scala") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val paramMap = List("max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
@@ -213,8 +213,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test getting number of features from a booster") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val booster = trainBooster(trainMat, testMat)
TestCase.assertEquals(booster.getNumFeature, 127)

View File

@@ -1,255 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit
import java.nio.{ByteBuffer, ByteOrder}
import akka.actor.{ActorRef, ActorSystem}
import akka.io.Tcp
import akka.testkit.{ImplicitSender, TestFSMRef, TestKit, TestProbe}
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler._
import ml.dmlc.xgboost4j.scala.rabit.util.LinkMap
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpecLike, Matchers}
import scala.concurrent.Promise
object RabitTrackerConnectionHandlerTest {
  /** Pack the given Ints, in order, into a native-endian ByteString. */
  def intSeqToByteString(seq: Seq[Int]): ByteString = {
    val encoded = seq.foldLeft(
      ByteBuffer.allocate(seq.length * 4).order(ByteOrder.nativeOrder())
    ) { (buf, value) => buf.putInt(value) }
    encoded.flip()
    ByteString.fromByteBuffer(encoded)
  }
}
/**
 * FSM-level tests for RabitWorkerHandler: each test plays the worker side of the
 * socket via a Tcp probe and asserts on state transitions plus the messages the
 * handler forwards to the tracker probe. Probe expectations are order-sensitive.
 */
@RunWith(classOf[JUnitRunner])
class RabitTrackerConnectionHandlerTest
  extends TestKit(ActorSystem("RabitTrackerConnectionHandlerTest"))
    with FlatSpecLike with Matchers with ImplicitSender {

  import RabitTrackerConnectionHandlerTest._

  // Handshake magic number echoed verbatim between worker and tracker.
  val magic = intSeqToByteString(List(0xff99))

  "RabitTrackerConnectionHandler" should "handle Rabit client 'start' command properly" in {
    val trackerProbe = TestProbe()
    val connProbe = TestProbe()
    val worldSize = 4
    val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
      trackerProbe.ref, connProbe.ref))

    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake

    // send mock magic number
    fsm ! Tcp.Received(magic)
    connProbe.expectMsg(Tcp.Write(magic))
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
    fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
    // ResumeReading should be seen once state transitions
    connProbe.expectMsg(Tcp.ResumeReading)

    // send mock tracker command in fragments: the handler should be able to handle it.
    val bufRank = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
    bufRank.putInt(0).putInt(worldSize).flip()
    val bufJobId = ByteBuffer.allocate(5).order(ByteOrder.nativeOrder())
    bufJobId.putInt(1).put(Array[Byte]('0')).flip()
    val bufCmd = ByteBuffer.allocate(9).order(ByteOrder.nativeOrder())
    bufCmd.putInt(5).put("start".getBytes()).flip()

    fsm ! Tcp.Received(ByteString.fromByteBuffer(bufRank))
    fsm ! Tcp.Received(ByteString.fromByteBuffer(bufJobId))
    // the state should not change for incomplete command data.
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand

    // send the last fragment, and expect message at tracker actor.
    fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
    trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))

    val linkMap = new LinkMap(worldSize)
    val assignedRank = linkMap.assignRank(0)
    trackerProbe.reply(assignedRank)
    connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
      assignedRank.toByteBuffer(worldSize)
    )))
    // reading should be suspended upon transitioning to BuildingLinkMap
    connProbe.expectMsg(Tcp.SuspendReading)

    // state should transition with according state data changes.
    fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
    fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
    connProbe.expectMsg(Tcp.ResumeReading)

    // since the connection handler in test has rank 0, it will not have any nodes to connect to.
    fsm ! Tcp.Received(intSeqToByteString(List(0)))
    trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))

    // return mock response to the connection handler
    val awaitConnPromise = Promise[AwaitingConnections]()
    awaitConnPromise.success(AwaitingConnections(Map.empty[Int, ActorRef],
      fsm.underlyingActor.getNeighboringWorkers.size
    ))
    fsm ! awaitConnPromise.future

    connProbe.expectMsg(Tcp.Write(
      intSeqToByteString(List(0, fsm.underlyingActor.getNeighboringWorkers.size))
    ))
    connProbe.expectMsg(Tcp.SuspendReading)
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingErrorCount
    connProbe.expectMsg(Tcp.ResumeReading)

    // send mock error count (0)
    fsm ! Tcp.Received(intSeqToByteString(List(0)))
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
    connProbe.expectMsg(Tcp.ResumeReading)

    // simulate Tcp.PeerClosed event first, then Tcp.Received to test handling of async events.
    fsm ! Tcp.PeerClosed
    // state should not transition
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber

    fsm ! Tcp.Received(intSeqToByteString(List(32768)))
    fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
    connProbe.expectMsg(Tcp.ResumeReading)
    trackerProbe.expectMsg(RabitWorkerHandler.WorkerStarted("localhost", 0, 2))

    val handlerStopProbe = TestProbe()
    handlerStopProbe watch fsm

    // simulate connections from other workers by mocking ReduceWaitCount commands
    fsm ! RabitWorkerHandler.ReduceWaitCount(1)
    fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
    fsm ! RabitWorkerHandler.ReduceWaitCount(1)
    // Once the wait count reaches zero the handler drops off the waiting list and stops.
    trackerProbe.expectMsg(RabitWorkerHandler.DropFromWaitingList(0))
    handlerStopProbe.expectTerminated(fsm)
    // all done.
  }

  it should "forward print command to tracker" in {
    val trackerProbe = TestProbe()
    val connProbe = TestProbe()
    val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
      trackerProbe.ref, connProbe.ref))

    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
    fsm ! Tcp.Received(magic)
    connProbe.expectMsg(Tcp.Write(magic))
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
    fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
    // ResumeReading should be seen once state transitions
    connProbe.expectMsg(Tcp.ResumeReading)

    // A complete "print" frame should be relayed to the tracker unchanged.
    val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!")
    fsm ! Tcp.Received(printCmd.encode)
    trackerProbe.expectMsg(printCmd)
  }

  it should "handle fragmented print command without throwing exception" in {
    val trackerProbe = TestProbe()
    val connProbe = TestProbe()
    val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
      trackerProbe.ref, connProbe.ref))

    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
    fsm ! Tcp.Received(magic)
    connProbe.expectMsg(Tcp.Write(magic))
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
    fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
    // ResumeReading should be seen once state transitions
    connProbe.expectMsg(Tcp.ResumeReading)

    val printCmd = WorkerTrackerPrint(0, 4, "0", "fragmented!")
    // 4 (rank: Int) + 4 (worldSize: Int) + (4+1) (jobId: String) + (4+5) (command: String) = 22
    val (partialMessage, remainder) = printCmd.encode.splitAt(22)
    // make sure that the partialMessage in itself is a valid command
    val partialMsgBuf = ByteBuffer.allocate(22).order(ByteOrder.nativeOrder())
    partialMsgBuf.put(partialMessage.asByteBuffer)
    RabitWorkerHandler.StructTrackerCommand.verify(partialMsgBuf) shouldBe true

    // The handler must buffer the partial frame and only emit once the remainder arrives.
    fsm ! Tcp.Received(partialMessage)
    fsm ! Tcp.Received(remainder)
    trackerProbe.expectMsg(printCmd)
  }

  it should "handle spill-over Tcp data correctly between state transition" in {
    val trackerProbe = TestProbe()
    val connProbe = TestProbe()
    val worldSize = 4
    val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
      trackerProbe.ref, connProbe.ref))

    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake

    // send mock magic number
    fsm ! Tcp.Received(magic)
    connProbe.expectMsg(Tcp.Write(magic))
    fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
    fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
    // ResumeReading should be seen once state transitions
    connProbe.expectMsg(Tcp.ResumeReading)

    // send mock tracker command in fragments: the handler should be able to handle it.
    val bufCmd = ByteBuffer.allocate(26).order(ByteOrder.nativeOrder())
    bufCmd.putInt(0).putInt(worldSize).putInt(1).put(Array[Byte]('0'))
      .putInt(5).put("start".getBytes())
      // spilled-over data
      .putInt(0).flip()

    // send data with 4 extra bytes corresponding to the next state.
    fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
    trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))

    val linkMap = new LinkMap(worldSize)
    val assignedRank = linkMap.assignRank(0)
    trackerProbe.reply(assignedRank)
    connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
      assignedRank.toByteBuffer(worldSize)
    )))
    // reading should be suspended upon transitioning to BuildingLinkMap
    connProbe.expectMsg(Tcp.SuspendReading)

    // state should transition with according state data changes.
    fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
    fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
    connProbe.expectMsg(Tcp.ResumeReading)

    // the handler should be able to handle spill-over data, and stash it until state transition.
    trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
  }
}