[jvm-packages] fix issue when spark job execution thread cannot return before we execute first() (#3758)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* sparkJobThread

* update

* fix issue when spark job execution thread cannot return before we execute first()
This commit is contained in:
Nan Zhu 2018-10-05 22:20:50 -07:00 committed by GitHub
parent 9e73087324
commit 785094db53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -258,20 +258,20 @@ object XGBoost extends Serializable {
case true => { case true => {
val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers) val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
partitionedData.mapPartitions(labeledPointGroups => { partitionedData.mapPartitions(labeledPointGroups => {
val watches = Watches.buildWatchesWithGroup(params, val watches = Watches.buildWatchesWithGroup(overriddenParams,
removeMissingValuesWithGroup(labeledPointGroups, missing), removeMissingValuesWithGroup(labeledPointGroups, missing),
getCacheDirName(useExternalMemory)) getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
obj, eval, prevBooster) obj, eval, prevBooster)
}).cache() }).cache()
} }
case false => { case false => {
val partitionedData = repartitionForTraining(trainingData, nWorkers) val partitionedData = repartitionForTraining(trainingData, nWorkers)
partitionedData.mapPartitions(labeledPoints => { partitionedData.mapPartitions(labeledPoints => {
val watches = Watches.buildWatches(params, val watches = Watches.buildWatches(overriddenParams,
removeMissingValues(labeledPoints, missing), removeMissingValues(labeledPoints, missing),
getCacheDirName(useExternalMemory)) getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
obj, eval, prevBooster) obj, eval, prevBooster)
}).cache() }).cache()
} }
@ -341,6 +341,9 @@ object XGBoost extends Serializable {
// Copies of the final booster and the corresponding metrics // Copies of the final booster and the corresponding metrics
// reside in each partition of the `distributedBoostersAndMetrics`. // reside in each partition of the `distributedBoostersAndMetrics`.
// Any of them can be used to create the model. // Any of them can be used to create the model.
// it's safe to block here forever, as the tracker has returned successfully, and the Spark
// job should have finished, so there is no reason why the thread cannot return
sparkJobThread.join()
val (booster, metrics) = distributedBoostersAndMetrics.first() val (booster, metrics) = distributedBoostersAndMetrics.first()
distributedBoostersAndMetrics.unpersist(false) distributedBoostersAndMetrics.unpersist(false)
(booster, metrics) (booster, metrics)