[jvm-packages] Performance consideration and Alignment input parameter of repartition function (#4049)
This commit is contained in:
parent
773ddbcfcb
commit
dade7c3aff
@ -191,8 +191,7 @@ object XGBoost extends Serializable {
|
||||
private def coPartitionNoGroupSets(
|
||||
trainingData: RDD[XGBLabeledPoint],
|
||||
evalSets: Map[String, RDD[XGBLabeledPoint]],
|
||||
params: Map[String, Any]) = {
|
||||
val nWorkers = params("num_workers").asInstanceOf[Int]
|
||||
nWorkers: Int) = {
|
||||
// eval_sets is supposed to be set by the caller of [[trainDistributed]]
|
||||
val allDatasets = Map("train" -> trainingData) ++ evalSets
|
||||
val repartitionedDatasets = allDatasets.map{case (name, rdd) =>
|
||||
@ -314,7 +313,7 @@ object XGBoost extends Serializable {
|
||||
obj, eval, prevBooster)
|
||||
}).cache()
|
||||
} else {
|
||||
coPartitionNoGroupSets(partitionedData, evalSetsMap, params).mapPartitions {
|
||||
coPartitionNoGroupSets(partitionedData, evalSetsMap, nWorkers).mapPartitions {
|
||||
nameAndLabeledPointSets =>
|
||||
val watches = Watches.buildWatches(
|
||||
nameAndLabeledPointSets.map {
|
||||
@ -344,7 +343,7 @@ object XGBoost extends Serializable {
|
||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
|
||||
}).cache()
|
||||
} else {
|
||||
coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, params).mapPartitions(
|
||||
coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, nWorkers).mapPartitions(
|
||||
labeledPointGroupSets => {
|
||||
val watches = Watches.buildWatchesWithGroup(
|
||||
labeledPointGroupSets.map {
|
||||
@ -452,8 +451,7 @@ object XGBoost extends Serializable {
|
||||
private def coPartitionGroupSets(
|
||||
aggedTrainingSet: RDD[Array[XGBLabeledPoint]],
|
||||
evalSets: Map[String, RDD[XGBLabeledPoint]],
|
||||
params: Map[String, Any]): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
|
||||
val nWorkers = params("num_workers").asInstanceOf[Int]
|
||||
nWorkers: Int): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
|
||||
val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map {
|
||||
case (name, rdd) => {
|
||||
val aggedRdd = aggByGroupInfo(rdd)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user