diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 2f1f261fb..7bb245035 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -407,7 +407,10 @@ object XGBoost extends Serializable { }} - val (booster, metrics) = boostersAndMetrics.collect()(0) + // The repartition step is to make the training stage a ShuffleMapStage, so that when one + // of the training tasks fails the training stage can retry. A ResultStage won't retry when + // it fails. + val (booster, metrics) = boostersAndMetrics.repartition(1).collect()(0) val trackerReturnVal = tracker.waitFor(0L) logger.info(s"Rabit returns with exit code $trackerReturnVal") if (trackerReturnVal != 0) {