[jvm-packages] Scala implementation of the Rabit tracker. (#1612)

* [jvm-packages] Scala implementation of the Rabit tracker.

A Scala implementation of RabitTracker that is interface-interchangable with the
Java implementation, ported from `tracker.py` in the
[dmlc-core project](https://github.com/dmlc/dmlc-core).

* [jvm-packages] Updated Akka dependency in pom.xml.

* Refactored the RabitTracker directory structure.

* Fixed premature stopping of connection handler.

Added a new finite state "AwaitingPortNumber" to explicitly wait for the
worker to send the port, and close the connection. Stopping the actor
prematurely sends a TCP RST to the worker, causing the worker to crash
on AssertionError.

* Added interface IRabitTracker so that user can switch implementations.

* Default timeout duration changes.

* Dependency for Akka tests.

* Removed the main function of RabitTracker.

* A skeleton for testing Akka-based Rabit tracker.

* waitFor() in RabitTracker no longer throws exceptions.

* Completed unit test for the 'start' command of Rabit tracker.

* Preliminary support for Rabit Allreduce via JNI (no prepare function support yet.)

* Fixed the default timeout duration.

* Use Java container to avoid serialization issues due to intermediate wrappers.

* Added tests for Allreduce/model training using Scala Rabit tracker.

* Added spill-over unit test for the Scala Rabit tracker.

* Fixed a typo.

* Overhaul of RabitTracker interface per code review.

  - Removed methods start() waitFor() (no arguments) from IRabitTracker.
  - The timeout in start(timeout) is now worker connection timeout, as tcp
    socket binding timeout is less intuitive.
  - Dropped time unit from start(...) and waitFor(...) methods; the default
    time unit is millisecond.
  - Moved random port number generation into the RabitTrackerHandler.
  - Moved all Rabit-related classes to package ml.dmlc.xgboost4j.scala.rabit.

* More code refactoring and comments.

* Unified timeout constants. Readable tracker status code.

* Add comments to indicate that allReduce is for tests only. Removed all other variants.

* Removed unused imports.

* Simplified signatures of training methods.

 - Moved TrackerConf into parameter map.
 - Changed GeneralParams so that TrackerConf becomes a standalone parameter.
 - Updated test cases accordingly.

* Changed monitoring strategies.

* Reverted monitoring changes.

* Update test case for Rabit AllReduce.

* Mix in UncaughtExceptionHandler into IRabitTracker to prevent tracker from hanging due to exceptions thrown by workers.

* More comprehensive test cases for exception handling and worker connection timeout.

* Handle executor loss due to unknown cause: the newly spawned executor will attempt to connect to the tracker. Interrupt tracker in such case.

* Per code-review, removed training timeout from TrackerConf. Timeout logic must be implemented explicitly and externally in the driver code.

* Reverted scalastyle-config changes.

* Visibility scope change. Interface tweaks.

* Use match pattern to handle tracker_conf parameter.

* Minor clarification in JNI code.

* Clearer intent in match pattern to suppress warnings.

* Removed Future from constructor. Block in start() and waitFor() instead.

* Revert inadvertent comment changes.

* Removed debugging information.

* Updated test cases that are a bit finicky.

* Added comments on the reasoning behind the unit tests for testing Rabit tracker robustness.
This commit is contained in:
Xin Yin
2016-12-07 09:35:42 -05:00
committed by Nan Zhu
parent 7078c41dad
commit e7fbc8591f
19 changed files with 1910 additions and 25 deletions

View File

@@ -0,0 +1,169 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.FunSuite
class RabitTrackerRobustnessSuite extends FunSuite with Utils {
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
that should be avoided.
*/
val sparkConf = new SparkConf().setMaster("local[*]")
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
implicit val sparkContext = new SparkContext(sparkConf)
sparkContext.setLogLevel("ERROR")
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new PyRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val workerCount: Int = numWorkers
/*
Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
last created worker. A cascading event chain will be triggered once the RuntimeException is
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
The Java RabitTracker class reacts to exceptions by killing the spawned process running
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
it is killed, the resulted connection failure will trigger the Rabit worker to call
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
running in separate containers. However, as unit tests are run in Spark local mode, in which
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
ultimately kills the entire process, which also happens to host the Spark driver, causing
the entire Spark application to crash.
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
the exception is thrown at last, ideally after all worker connections have been established.
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
process to ensure that pending worker connections are handled.
*/
val dummyTasks = rdd.mapPartitions { iter =>
Rabit.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Rabit.shutdown()
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0) != 0)
sparkContext.stop()
}
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
val sparkConf = new SparkConf().setMaster("local[*]")
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
implicit val sparkContext = new SparkContext(sparkConf)
sparkContext.setLogLevel("ERROR")
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val workerCount: Int = numWorkers
val dummyTasks = rdd.mapPartitions { iter =>
Rabit.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Rabit.shutdown()
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
sparkContext.stop()
}
test("test Scala RabitTracker's workerConnectionTimeout") {
val sparkConf = new SparkConf().setMaster("local[*]")
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
implicit val sparkContext = new SparkContext(sparkConf)
sparkContext.setLogLevel("ERROR")
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(500)
val trackerEnvs = tracker.getWorkerEnvs
val dummyTasks = rdd.mapPartitions { iter =>
val index = iter.next()
// simulate that the first worker cannot connect to tracker due to network issues.
if (index != 1) {
Rabit.init(trackerEnvs)
Thread.sleep(1000)
Rabit.shutdown()
}
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
// should fail due to connection timeout
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
sparkContext.stop()
}
}

View File

@@ -17,18 +17,60 @@
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
import scala.collection.mutable.ListBuffer
import scala.util.Random
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import scala.concurrent.duration._
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector => SparkVector, Vectors}
import org.apache.spark.ml.linalg.{Vectors, Vector => SparkVector}
import org.apache.spark.rdd.RDD
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
val tracker = new RabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
val rawData = rdd.mapPartitions { iter =>
Iterator(iter.toArray)
}.collect()
val maxVec = (0 until vectorLength).toArray.map { j =>
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
}
val allReduceResults = rdd.mapPartitions { iter =>
Rabit.init(trackerEnvs)
val arr = iter.toArray
val results = Rabit.allReduce(arr, Rabit.OpType.MAX)
Rabit.shutdown()
Iterator(results)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
allReduceResults.foreachPartition(() => _)
val byPartitionResults = allReduceResults.collect()
assert(byPartitionResults(0).length == vectorLength)
collectedAllReduceResults.put(byPartitionResults(0))
}
}
sparkThread.start()
assert(tracker.waitFor(0L) == 0)
sparkThread.join()
assert(collectedAllReduceResults.poll().sameElements(maxVec))
}
test("build RDD containing boosters with the specified worker number") {
val trainingRDD = buildTrainingRDD(sc)
@@ -36,7 +78,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
trainingRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap,
new scala.collection.mutable.HashMap[String, String],
new java.util.HashMap[String, String](),
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true)
val boosterCount = boosterRDD.count()
assert(boosterCount === 2)
@@ -59,6 +101,21 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
cleanExternalCache("XGBoostSuite")
}
test("training with Scala-implemented Rabit tracker") {
val eval = new EvalError()
val trainingRDD = buildTrainingRDD(sc)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
}
test("test with dense vectors containing missing value") {
def buildDenseRDD(): RDD[LabeledPoint] = {
val nrow = 100