[jvm-packages] Deterministically XGBoost training on exception (#2405)

Previously the code relied on the tracker process being terminated
by the OS, which was not the case on Windows.

Closes #2394
This commit is contained in:
Sergei Lebedev 2017-06-13 05:19:28 +02:00 committed by Nan Zhu
parent 34dfe2f6de
commit 0db37c05bd
4 changed files with 27 additions and 16 deletions

View File

@ -283,22 +283,26 @@ object XGBoost extends Serializable {
"instance of TrackerConf.") "instance of TrackerConf.")
} }
val tracker = startTracker(nWorkers, trackerConf) val tracker = startTracker(nWorkers, trackerConf)
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext) try {
val boosters = buildDistributedBoosters(trainingData, overridedConfMap, val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing) val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
val sparkJobThread = new Thread() { tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
override def run() { val sparkJobThread = new Thread() {
// force the job override def run() {
boosters.foreachPartition(() => _) // force the job
boosters.foreachPartition(() => _)
}
} }
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal")
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
isClsTask)
} finally {
tracker.stop()
} }
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal")
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
isClsTask)
} }
private def postTrackerReturnProcessing( private def postTrackerReturnProcessing(

View File

@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit;
* brokers connections between workers. * brokers connections between workers.
*/ */
public interface IRabitTracker extends Thread.UncaughtExceptionHandler { public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
public enum TrackerStatus { enum TrackerStatus {
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3); SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
private int statusCode; private int statusCode;
@ -38,6 +38,7 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
Map<String, String> getWorkerEnvs(); Map<String, String> getWorkerEnvs();
boolean start(long workerConnectionTimeout); boolean start(long workerConnectionTimeout);
void stop();
// taskExecutionTimeout has no effect in current version of XGBoost. // taskExecutionTimeout has no effect in current version of XGBoost.
int waitFor(long taskExecutionTimeout); int waitFor(long taskExecutionTimeout);
} }

View File

@ -139,7 +139,7 @@ public class RabitTracker implements IRabitTracker {
} }
} }
private void stop() { public void stop() {
if (trackerProcess.get() != null) { if (trackerProcess.get() != null) {
trackerProcess.get().destroy(); trackerProcess.get().destroy();
} }

View File

@ -127,6 +127,12 @@ private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
} }
} }
def stop(): Unit = {
if (!system.isTerminated) {
system.shutdown()
}
}
/** /**
* Get a Map of necessary environment variables to initiate Rabit workers. * Get a Map of necessary environment variables to initiate Rabit workers.
* *