[jvm-packages] fix issue when spark job execution thread cannot return before we execute first() (#3758)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* sparkJobThread

* update

* fix issue when spark job execution thread cannot return before we execute first()
This commit is contained in:
Nan Zhu 2018-10-05 22:20:50 -07:00 committed by GitHub
parent 9e73087324
commit 785094db53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -258,20 +258,20 @@ object XGBoost extends Serializable {
case true => {
val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
partitionedData.mapPartitions(labeledPointGroups => {
val watches = Watches.buildWatchesWithGroup(params,
val watches = Watches.buildWatchesWithGroup(overriddenParams,
removeMissingValuesWithGroup(labeledPointGroups, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
obj, eval, prevBooster)
}).cache()
}
case false => {
val partitionedData = repartitionForTraining(trainingData, nWorkers)
partitionedData.mapPartitions(labeledPoints => {
val watches = Watches.buildWatches(params,
val watches = Watches.buildWatches(overriddenParams,
removeMissingValues(labeledPoints, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
obj, eval, prevBooster)
}).cache()
}
@ -341,6 +341,9 @@ object XGBoost extends Serializable {
// Copies of the final booster and the corresponding metrics
// reside in each partition of the `distributedBoostersAndMetrics`.
// Any of them can be used to create the model.
// it's safe to block here forever: the tracker has returned successfully and the Spark
// job should have finished, so there is no reason the thread cannot return
sparkJobThread.join()
val (booster, metrics) = distributedBoostersAndMetrics.first()
distributedBoostersAndMetrics.unpersist(false)
(booster, metrics)