[jvm-packages] Deterministically XGBoost training on exception (#2405)

Previously the code relied on the tracker process being terminated
by the OS, which was not the case on Windows.

Closes #2394
This commit is contained in:
Sergei Lebedev
2017-06-13 05:19:28 +02:00
committed by Nan Zhu
parent 34dfe2f6de
commit 0db37c05bd
4 changed files with 27 additions and 16 deletions

View File

@@ -283,22 +283,26 @@ object XGBoost extends Serializable {
"instance of TrackerConf.")
}
val tracker = startTracker(nWorkers, trackerConf)
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
val sparkJobThread = new Thread() {
override def run() {
// force the job
boosters.foreachPartition(() => _)
try {
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
val sparkJobThread = new Thread() {
override def run() {
// force the job
boosters.foreachPartition(() => _)
}
}
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal")
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
isClsTask)
} finally {
tracker.stop()
}
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal")
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
isClsTask)
}
private def postTrackerReturnProcessing(