[jvm-packages] Performance considerations and aligned input parameters for the repartition functions (#4049)

This commit is contained in:
KyleLi1985 2019-01-08 00:38:05 +08:00 committed by Nan Zhu
parent 773ddbcfcb
commit dade7c3aff

View File

@ -191,8 +191,7 @@ object XGBoost extends Serializable {
private def coPartitionNoGroupSets( private def coPartitionNoGroupSets(
trainingData: RDD[XGBLabeledPoint], trainingData: RDD[XGBLabeledPoint],
evalSets: Map[String, RDD[XGBLabeledPoint]], evalSets: Map[String, RDD[XGBLabeledPoint]],
params: Map[String, Any]) = { nWorkers: Int) = {
val nWorkers = params("num_workers").asInstanceOf[Int]
// eval_sets is supposed to be set by the caller of [[trainDistributed]] // eval_sets is supposed to be set by the caller of [[trainDistributed]]
val allDatasets = Map("train" -> trainingData) ++ evalSets val allDatasets = Map("train" -> trainingData) ++ evalSets
val repartitionedDatasets = allDatasets.map{case (name, rdd) => val repartitionedDatasets = allDatasets.map{case (name, rdd) =>
@ -314,7 +313,7 @@ object XGBoost extends Serializable {
obj, eval, prevBooster) obj, eval, prevBooster)
}).cache() }).cache()
} else { } else {
coPartitionNoGroupSets(partitionedData, evalSetsMap, params).mapPartitions { coPartitionNoGroupSets(partitionedData, evalSetsMap, nWorkers).mapPartitions {
nameAndLabeledPointSets => nameAndLabeledPointSets =>
val watches = Watches.buildWatches( val watches = Watches.buildWatches(
nameAndLabeledPointSets.map { nameAndLabeledPointSets.map {
@ -344,7 +343,7 @@ object XGBoost extends Serializable {
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster) buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
}).cache() }).cache()
} else { } else {
coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, params).mapPartitions( coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, nWorkers).mapPartitions(
labeledPointGroupSets => { labeledPointGroupSets => {
val watches = Watches.buildWatchesWithGroup( val watches = Watches.buildWatchesWithGroup(
labeledPointGroupSets.map { labeledPointGroupSets.map {
@ -452,8 +451,7 @@ object XGBoost extends Serializable {
private def coPartitionGroupSets( private def coPartitionGroupSets(
aggedTrainingSet: RDD[Array[XGBLabeledPoint]], aggedTrainingSet: RDD[Array[XGBLabeledPoint]],
evalSets: Map[String, RDD[XGBLabeledPoint]], evalSets: Map[String, RDD[XGBLabeledPoint]],
params: Map[String, Any]): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = { nWorkers: Int): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
val nWorkers = params("num_workers").asInstanceOf[Int]
val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map { val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map {
case (name, rdd) => { case (name, rdd) => {
val aggedRdd = aggByGroupInfo(rdd) val aggedRdd = aggByGroupInfo(rdd)