[jvm-packages] fix issue when spark job execution thread cannot return before we execute first() (#3758)
* add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix scalastyle error * fix scalastyle error * sparkJobThread * update * fix issue when spark job execution thread cannot return before we execute first()
This commit is contained in:
parent
9e73087324
commit
785094db53
@ -258,20 +258,20 @@ object XGBoost extends Serializable {
|
||||
case true => {
|
||||
val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
|
||||
partitionedData.mapPartitions(labeledPointGroups => {
|
||||
val watches = Watches.buildWatchesWithGroup(params,
|
||||
val watches = Watches.buildWatchesWithGroup(overriddenParams,
|
||||
removeMissingValuesWithGroup(labeledPointGroups, missing),
|
||||
getCacheDirName(useExternalMemory))
|
||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
||||
buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
|
||||
obj, eval, prevBooster)
|
||||
}).cache()
|
||||
}
|
||||
case false => {
|
||||
val partitionedData = repartitionForTraining(trainingData, nWorkers)
|
||||
partitionedData.mapPartitions(labeledPoints => {
|
||||
val watches = Watches.buildWatches(params,
|
||||
val watches = Watches.buildWatches(overriddenParams,
|
||||
removeMissingValues(labeledPoints, missing),
|
||||
getCacheDirName(useExternalMemory))
|
||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
||||
buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
|
||||
obj, eval, prevBooster)
|
||||
}).cache()
|
||||
}
|
||||
@ -341,6 +341,9 @@ object XGBoost extends Serializable {
|
||||
// Copies of the final booster and the corresponding metrics
|
||||
// reside in each partition of the `distributedBoostersAndMetrics`.
|
||||
// Any of them can be used to create the model.
|
||||
// it's safe to block here forever, as the tracker has returned successfully, and the Spark
|
||||
// job should have finished; there is no reason why the thread could not return
|
||||
sparkJobThread.join()
|
||||
val (booster, metrics) = distributedBoostersAndMetrics.first()
|
||||
distributedBoostersAndMetrics.unpersist(false)
|
||||
(booster, metrics)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user