[jvm-packages] fix issue when spark job execution thread cannot return before we execute first() (#3758)
* add back train method but mark as deprecated
* add back train method but mark as deprecated
* add back train method but mark as deprecated
* add back train method but mark as deprecated
* fix scalastyle error
* fix scalastyle error
* fix scalastyle error
* fix scalastyle error
* sparkJobThread
* update
* fix issue when spark job execution thread cannot return before we execute first()
This commit is contained in:
parent
9e73087324
commit
785094db53
@ -258,20 +258,20 @@ object XGBoost extends Serializable {
|
|||||||
case true => {
|
case true => {
|
||||||
val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
|
val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
|
||||||
partitionedData.mapPartitions(labeledPointGroups => {
|
partitionedData.mapPartitions(labeledPointGroups => {
|
||||||
val watches = Watches.buildWatchesWithGroup(params,
|
val watches = Watches.buildWatchesWithGroup(overriddenParams,
|
||||||
removeMissingValuesWithGroup(labeledPointGroups, missing),
|
removeMissingValuesWithGroup(labeledPointGroups, missing),
|
||||||
getCacheDirName(useExternalMemory))
|
getCacheDirName(useExternalMemory))
|
||||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
|
||||||
obj, eval, prevBooster)
|
obj, eval, prevBooster)
|
||||||
}).cache()
|
}).cache()
|
||||||
}
|
}
|
||||||
case false => {
|
case false => {
|
||||||
val partitionedData = repartitionForTraining(trainingData, nWorkers)
|
val partitionedData = repartitionForTraining(trainingData, nWorkers)
|
||||||
partitionedData.mapPartitions(labeledPoints => {
|
partitionedData.mapPartitions(labeledPoints => {
|
||||||
val watches = Watches.buildWatches(params,
|
val watches = Watches.buildWatches(overriddenParams,
|
||||||
removeMissingValues(labeledPoints, missing),
|
removeMissingValues(labeledPoints, missing),
|
||||||
getCacheDirName(useExternalMemory))
|
getCacheDirName(useExternalMemory))
|
||||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
|
||||||
obj, eval, prevBooster)
|
obj, eval, prevBooster)
|
||||||
}).cache()
|
}).cache()
|
||||||
}
|
}
|
||||||
@ -341,6 +341,9 @@ object XGBoost extends Serializable {
|
|||||||
// Copies of the final booster and the corresponding metrics
|
// Copies of the final booster and the corresponding metrics
|
||||||
// reside in each partition of the `distributedBoostersAndMetrics`.
|
// reside in each partition of the `distributedBoostersAndMetrics`.
|
||||||
// Any of them can be used to create the model.
|
// Any of them can be used to create the model.
|
||||||
|
// it's safe to block here forever, as the tracker has returned successfully, and the Spark
|
||||||
|
// job should have finished, so there is no reason the thread cannot return
|
||||||
|
sparkJobThread.join()
|
||||||
val (booster, metrics) = distributedBoostersAndMetrics.first()
|
val (booster, metrics) = distributedBoostersAndMetrics.first()
|
||||||
distributedBoostersAndMetrics.unpersist(false)
|
distributedBoostersAndMetrics.unpersist(false)
|
||||||
(booster, metrics)
|
(booster, metrics)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user