diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 68b887b23..a7c802dc1 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -62,14 +62,28 @@ object XGBoost extends Serializable { require(tracker.start(), "FAULT: Failed to start tracker") boosters = buildDistributedBoosters(trainingData, configMap, tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval) - // force the job - boosters.foreachPartition(_ => ()) - val booster = boosters.first() + @volatile var booster: Booster = null + val sparkJobThread = new Thread() { + override def run() { + // force the job + boosters.foreachPartition(_ => ()) + } + } + sparkJobThread.start() val returnVal = tracker.waitFor() logger.info(s"Rabit returns with exit code $returnVal") if (returnVal == 0) { + booster = boosters.first() Some(booster) } else { + try { + if (sparkJobThread.isAlive) { + sparkJobThread.interrupt() + } + } catch { + case ie: InterruptedException => + logger.info("spark job thread is interrupted") + } None } }