From dade7c3aff1044f83339fd1393f78501c2fd5841 Mon Sep 17 00:00:00 2001 From: KyleLi1985 <40689156+KyleLi1985@users.noreply.github.com> Date: Tue, 8 Jan 2019 00:38:05 +0800 Subject: [PATCH] [jvm-packages] Performance consideration and Alignment input parameter of repartition function (#4049) --- .../scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index c93f14353..4fe82b271 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -191,8 +191,7 @@ object XGBoost extends Serializable { private def coPartitionNoGroupSets( trainingData: RDD[XGBLabeledPoint], evalSets: Map[String, RDD[XGBLabeledPoint]], - params: Map[String, Any]) = { - val nWorkers = params("num_workers").asInstanceOf[Int] + nWorkers: Int) = { // eval_sets is supposed to be set by the caller of [[trainDistributed]] val allDatasets = Map("train" -> trainingData) ++ evalSets val repartitionedDatasets = allDatasets.map{case (name, rdd) => @@ -314,7 +313,7 @@ object XGBoost extends Serializable { obj, eval, prevBooster) }).cache() } else { - coPartitionNoGroupSets(partitionedData, evalSetsMap, params).mapPartitions { + coPartitionNoGroupSets(partitionedData, evalSetsMap, nWorkers).mapPartitions { nameAndLabeledPointSets => val watches = Watches.buildWatches( nameAndLabeledPointSets.map { @@ -344,7 +343,7 @@ object XGBoost extends Serializable { buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster) }).cache() } else { - coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, params).mapPartitions( + coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, nWorkers).mapPartitions( labeledPointGroupSets => { val watches = Watches.buildWatchesWithGroup( labeledPointGroupSets.map { @@ -452,8 +451,7 @@ object XGBoost extends Serializable { private def coPartitionGroupSets( aggedTrainingSet: RDD[Array[XGBLabeledPoint]], evalSets: Map[String, RDD[XGBLabeledPoint]], - params: Map[String, Any]): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = { - val nWorkers = params("num_workers").asInstanceOf[Int] + nWorkers: Int): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = { val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map { case (name, rdd) => { val aggedRdd = aggByGroupInfo(rdd)