[jvm-packages] Deterministically XGBoost training on exception (#2405)
Previously the code relied on the tracker process being terminated by the OS, which was not the case on Windows. Closes #2394
This commit is contained in:
parent
34dfe2f6de
commit
0db37c05bd
@ -283,6 +283,7 @@ object XGBoost extends Serializable {
|
|||||||
"instance of TrackerConf.")
|
"instance of TrackerConf.")
|
||||||
}
|
}
|
||||||
val tracker = startTracker(nWorkers, trackerConf)
|
val tracker = startTracker(nWorkers, trackerConf)
|
||||||
|
try {
|
||||||
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
|
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
|
||||||
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
|
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
|
||||||
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
|
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
|
||||||
@ -299,6 +300,9 @@ object XGBoost extends Serializable {
|
|||||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||||
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
|
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
|
||||||
isClsTask)
|
isClsTask)
|
||||||
|
} finally {
|
||||||
|
tracker.stop()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private def postTrackerReturnProcessing(
|
private def postTrackerReturnProcessing(
|
||||||
|
|||||||
@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit;
|
|||||||
* brokers connections between workers.
|
* brokers connections between workers.
|
||||||
*/
|
*/
|
||||||
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
||||||
public enum TrackerStatus {
|
enum TrackerStatus {
|
||||||
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
|
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
|
||||||
|
|
||||||
private int statusCode;
|
private int statusCode;
|
||||||
@ -38,6 +38,7 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
|||||||
|
|
||||||
Map<String, String> getWorkerEnvs();
|
Map<String, String> getWorkerEnvs();
|
||||||
boolean start(long workerConnectionTimeout);
|
boolean start(long workerConnectionTimeout);
|
||||||
|
void stop();
|
||||||
// taskExecutionTimeout has no effect in current version of XGBoost.
|
// taskExecutionTimeout has no effect in current version of XGBoost.
|
||||||
int waitFor(long taskExecutionTimeout);
|
int waitFor(long taskExecutionTimeout);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -139,7 +139,7 @@ public class RabitTracker implements IRabitTracker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void stop() {
|
public void stop() {
|
||||||
if (trackerProcess.get() != null) {
|
if (trackerProcess.get() != null) {
|
||||||
trackerProcess.get().destroy();
|
trackerProcess.get().destroy();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -127,6 +127,12 @@ private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def stop(): Unit = {
|
||||||
|
if (!system.isTerminated) {
|
||||||
|
system.shutdown()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get a Map of necessary environment variables to initiate Rabit workers.
|
* Get a Map of necessary environment variables to initiate Rabit workers.
|
||||||
*
|
*
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user