diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 99a7495b0..8e730667d 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -283,22 +283,26 @@ object XGBoost extends Serializable { "instance of TrackerConf.") } val tracker = startTracker(nWorkers, trackerConf) - val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext) - val boosters = buildDistributedBoosters(trainingData, overridedConfMap, - tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing) - val sparkJobThread = new Thread() { - override def run() { - // force the job - boosters.foreachPartition(() => _) + try { + val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext) + val boosters = buildDistributedBoosters(trainingData, overridedConfMap, + tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing) + val sparkJobThread = new Thread() { + override def run() { + // force the job + boosters.foreachPartition(() => _) + } } + sparkJobThread.setUncaughtExceptionHandler(tracker) + sparkJobThread.start() + val isClsTask = isClassificationTask(params) + val trackerReturnVal = tracker.waitFor(0L) + logger.info(s"Rabit returns with exit code $trackerReturnVal") + postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread, + isClsTask) + } finally { + tracker.stop() } - sparkJobThread.setUncaughtExceptionHandler(tracker) - sparkJobThread.start() - val isClsTask = isClassificationTask(params) - val trackerReturnVal = tracker.waitFor(0L) - logger.info(s"Rabit returns with exit code $trackerReturnVal") - postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread, - isClsTask) } private def postTrackerReturnProcessing( diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java index 2a2fcd423..984fb80e6 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit; * brokers connections between workers. */ public interface IRabitTracker extends Thread.UncaughtExceptionHandler { - public enum TrackerStatus { + enum TrackerStatus { SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3); private int statusCode; @@ -38,6 +38,7 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler { Map getWorkerEnvs(); boolean start(long workerConnectionTimeout); + void stop(); // taskExecutionTimeout has no effect in current version of XGBoost. int waitFor(long taskExecutionTimeout); } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java index d2008cd7f..888d501db 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java @@ -139,7 +139,7 @@ public class RabitTracker implements IRabitTracker { } } - private void stop() { + public void stop() { if (trackerProcess.get() != null) { trackerProcess.get().destroy(); } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala index d6ca42e75..00cef158d 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala @@ -127,6 +127,12 @@ private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None, } } + def stop(): Unit = { + if (!system.isTerminated) { + system.shutdown() + } + } + /** * Get a Map of necessary environment variables to initiate Rabit workers. *