[jvm-packages] Scala implementation of the Rabit tracker. (#1612)

* [jvm-packages] Scala implementation of the Rabit tracker.

A Scala implementation of RabitTracker that is interface-interchangable with the
Java implementation, ported from `tracker.py` in the
[dmlc-core project](https://github.com/dmlc/dmlc-core).

* [jvm-packages] Updated Akka dependency in pom.xml.

* Refactored the RabitTracker directory structure.

* Fixed premature stopping of connection handler.

Added a new finite state "AwaitingPortNumber" to explicitly wait for the
worker to send the port, and close the connection. Stopping the actor
prematurely sends a TCP RST to the worker, causing the worker to crash
on AssertionError.

* Added interface IRabitTracker so that user can switch implementations.

* Default timeout duration changes.

* Dependency for Akka tests.

* Removed the main function of RabitTracker.

* A skeleton for testing Akka-based Rabit tracker.

* waitFor() in RabitTracker no longer throws exceptions.

* Completed unit test for the 'start' command of Rabit tracker.

* Preliminary support for Rabit Allreduce via JNI (no prepare function support yet.)

* Fixed the default timeout duration.

* Use Java container to avoid serialization issues due to intermediate wrappers.

* Added tests for Allreduce/model training using Scala Rabit tracker.

* Added spill-over unit test for the Scala Rabit tracker.

* Fixed a typo.

* Overhaul of RabitTracker interface per code review.

  - Removed methods start() waitFor() (no arguments) from IRabitTracker.
  - The timeout in start(timeout) is now worker connection timeout, as tcp
    socket binding timeout is less intuitive.
  - Dropped time unit from start(...) and waitFor(...) methods; the default
    time unit is millisecond.
  - Moved random port number generation into the RabitTrackerHandler.
  - Moved all Rabit-related classes to package ml.dmlc.xgboost4j.scala.rabit.

* More code refactoring and comments.

* Unified timeout constants. Readable tracker status code.

* Add comments to indicate that allReduce is for tests only. Removed all other variants.

* Removed unused imports.

* Simplified signatures of training methods.

 - Moved TrackerConf into parameter map.
 - Changed GeneralParams so that TrackerConf becomes a standalone parameter.
 - Updated test cases accordingly.

* Changed monitoring strategies.

* Reverted monitoring changes.

* Update test case for Rabit AllReduce.

* Mix in UncaughtExceptionHandler into IRabitTracker to prevent tracker from hanging due to exceptions thrown by workers.

* More comprehensive test cases for exception handling and worker connection timeout.

* Handle executor loss due to unknown cause: the newly spawned executor will attempt to connect to the tracker. Interrupt tracker in such case.

* Per code-review, removed training timeout from TrackerConf. Timeout logic must be implemented explicitly and externally in the driver code.

* Reverted scalastyle-config changes.

* Visibility scope change. Interface tweaks.

* Use match pattern to handle tracker_conf parameter.

* Minor clarification in JNI code.

* Clearer intent in match pattern to suppress warnings.

* Removed Future from constructor. Block in start() and waitFor() instead.

* Revert inadvertent comment changes.

* Removed debugging information.

* Updated test cases that are a bit finicky.

* Added comments on the reasoning behind the unit tests for testing Rabit tracker robustness.
This commit is contained in:
Xin Yin
2016-12-07 09:35:42 -05:00
committed by Nan Zhu
parent 7078c41dad
commit e7fbc8591f
19 changed files with 1910 additions and 25 deletions

View File

@@ -94,6 +94,7 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
long max_elem = cbatch.offset[cbatch.size];
cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0);
cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);
CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
<< "batch.index.length must equal batch.offset.back()";
CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem)
@@ -756,3 +757,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitAllreduce
* Signature: (Ljava/nio/ByteBuffer;III)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
(JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
RabitAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op, NULL, NULL);
return 0;
}