[jvm-packages] Performance consideration and alignment of the input parameter of the repartition functions (#4049)
This commit is contained in:
parent
773ddbcfcb
commit
dade7c3aff
@ -191,8 +191,7 @@ object XGBoost extends Serializable {
|
|||||||
private def coPartitionNoGroupSets(
|
private def coPartitionNoGroupSets(
|
||||||
trainingData: RDD[XGBLabeledPoint],
|
trainingData: RDD[XGBLabeledPoint],
|
||||||
evalSets: Map[String, RDD[XGBLabeledPoint]],
|
evalSets: Map[String, RDD[XGBLabeledPoint]],
|
||||||
params: Map[String, Any]) = {
|
nWorkers: Int) = {
|
||||||
val nWorkers = params("num_workers").asInstanceOf[Int]
|
|
||||||
// eval_sets is supposed to be set by the caller of [[trainDistributed]]
|
// eval_sets is supposed to be set by the caller of [[trainDistributed]]
|
||||||
val allDatasets = Map("train" -> trainingData) ++ evalSets
|
val allDatasets = Map("train" -> trainingData) ++ evalSets
|
||||||
val repartitionedDatasets = allDatasets.map{case (name, rdd) =>
|
val repartitionedDatasets = allDatasets.map{case (name, rdd) =>
|
||||||
@ -314,7 +313,7 @@ object XGBoost extends Serializable {
|
|||||||
obj, eval, prevBooster)
|
obj, eval, prevBooster)
|
||||||
}).cache()
|
}).cache()
|
||||||
} else {
|
} else {
|
||||||
coPartitionNoGroupSets(partitionedData, evalSetsMap, params).mapPartitions {
|
coPartitionNoGroupSets(partitionedData, evalSetsMap, nWorkers).mapPartitions {
|
||||||
nameAndLabeledPointSets =>
|
nameAndLabeledPointSets =>
|
||||||
val watches = Watches.buildWatches(
|
val watches = Watches.buildWatches(
|
||||||
nameAndLabeledPointSets.map {
|
nameAndLabeledPointSets.map {
|
||||||
@ -344,7 +343,7 @@ object XGBoost extends Serializable {
|
|||||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
|
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
|
||||||
}).cache()
|
}).cache()
|
||||||
} else {
|
} else {
|
||||||
coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, params).mapPartitions(
|
coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, nWorkers).mapPartitions(
|
||||||
labeledPointGroupSets => {
|
labeledPointGroupSets => {
|
||||||
val watches = Watches.buildWatchesWithGroup(
|
val watches = Watches.buildWatchesWithGroup(
|
||||||
labeledPointGroupSets.map {
|
labeledPointGroupSets.map {
|
||||||
@ -452,8 +451,7 @@ object XGBoost extends Serializable {
|
|||||||
private def coPartitionGroupSets(
|
private def coPartitionGroupSets(
|
||||||
aggedTrainingSet: RDD[Array[XGBLabeledPoint]],
|
aggedTrainingSet: RDD[Array[XGBLabeledPoint]],
|
||||||
evalSets: Map[String, RDD[XGBLabeledPoint]],
|
evalSets: Map[String, RDD[XGBLabeledPoint]],
|
||||||
params: Map[String, Any]): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
|
nWorkers: Int): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
|
||||||
val nWorkers = params("num_workers").asInstanceOf[Int]
|
|
||||||
val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map {
|
val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map {
|
||||||
case (name, rdd) => {
|
case (name, rdd) => {
|
||||||
val aggedRdd = aggByGroupInfo(rdd)
|
val aggedRdd = aggByGroupInfo(rdd)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user