From 04c99683c3b4f27afb8a877c67721ccd6715881a Mon Sep 17 00:00:00 2001 From: jinmfeng001 <102719116+jinmfeng001@users.noreply.github.com> Date: Thu, 3 Aug 2023 23:40:04 +0800 Subject: [PATCH] Change training stage from ResultStage to ShuffleMapStage (#9423) --- .../main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 2f1f261fb..7bb245035 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -407,7 +407,10 @@ object XGBoost extends Serializable { }} - val (booster, metrics) = boostersAndMetrics.collect()(0) + // The repartition step is to make training stage as ShuffleMapStage, so that when one + // of the training task fails the training stage can retry. ResultStage won't retry when + // it fails. + val (booster, metrics) = boostersAndMetrics.repartition(1).collect()(0) val trackerReturnVal = tracker.waitFor(0L) logger.info(s"Rabit returns with exit code $trackerReturnVal") if (trackerReturnVal != 0) {