[jvm-packages] Deterministically XGBoost training on exception (#2405)
Previously the code relied on the tracker process being terminated by the OS, which was not the case on Windows. Closes #2394
This commit is contained in:
@@ -283,22 +283,26 @@ object XGBoost extends Serializable {
|
||||
"instance of TrackerConf.")
|
||||
}
|
||||
val tracker = startTracker(nWorkers, trackerConf)
|
||||
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
|
||||
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
|
||||
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
// force the job
|
||||
boosters.foreachPartition(() => _)
|
||||
try {
|
||||
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
|
||||
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
|
||||
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
// force the job
|
||||
boosters.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkJobThread.start()
|
||||
val isClsTask = isClassificationTask(params)
|
||||
val trackerReturnVal = tracker.waitFor(0L)
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
|
||||
isClsTask)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
}
|
||||
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkJobThread.start()
|
||||
val isClsTask = isClassificationTask(params)
|
||||
val trackerReturnVal = tracker.waitFor(0L)
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
|
||||
isClsTask)
|
||||
}
|
||||
|
||||
private def postTrackerReturnProcessing(
|
||||
|
||||
Reference in New Issue
Block a user