diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index d31810098..038a65889 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -258,20 +258,20 @@ object XGBoost extends Serializable { case true => { val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers) partitionedData.mapPartitions(labeledPointGroups => { - val watches = Watches.buildWatchesWithGroup(params, + val watches = Watches.buildWatchesWithGroup(overriddenParams, removeMissingValuesWithGroup(labeledPointGroups, missing), getCacheDirName(useExternalMemory)) - buildDistributedBooster(watches, params, rabitEnv, checkpointRound, + buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound, obj, eval, prevBooster) }).cache() } case false => { val partitionedData = repartitionForTraining(trainingData, nWorkers) partitionedData.mapPartitions(labeledPoints => { - val watches = Watches.buildWatches(params, + val watches = Watches.buildWatches(overriddenParams, removeMissingValues(labeledPoints, missing), getCacheDirName(useExternalMemory)) - buildDistributedBooster(watches, params, rabitEnv, checkpointRound, + buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound, obj, eval, prevBooster) }).cache() } @@ -341,6 +341,9 @@ object XGBoost extends Serializable { // Copies of the final booster and the corresponding metrics // reside in each partition of the `distributedBoostersAndMetrics`. // Any of them can be used to create the model. + // it's safe to block here forever, as the tracker has returned successfully, and the Spark + // job should have finished, there is no reason for the thread cannot return + sparkJobThread.join() val (booster, metrics) = distributedBoostersAndMetrics.first() distributedBoostersAndMetrics.unpersist(false) (booster, metrics)