diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 7a35b85ec..4bbef3cde 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -17,6 +17,7 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import scala.collection.mutable
+import scala.collection.JavaConverters._
 
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.TaskContext
@@ -38,13 +39,13 @@ object XGBoost extends Serializable {
   private[spark] def buildDistributedBoosters(
       trainingData: RDD[LabeledPoint],
       xgBoostConfMap: Map[String, AnyRef],
+      rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
     import DataUtils._
     trainingData.repartition(numWorkers).mapPartitions {
       trainingSamples =>
-        Rabit.init(new java.util.HashMap[String, String]() {
-          put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
-        })
+        rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
+        Rabit.init(rabitEnv.asJava)
         val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
         val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
           watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval)
@@ -59,7 +60,8 @@ object XGBoost extends Serializable {
     val sc = trainingData.sparkContext
     val tracker = new RabitTracker(numWorkers)
     require(tracker.start(), "FAULT: Failed to start tracker")
-    boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval)
+    boosters = buildDistributedBoosters(trainingData, configMap,
+      tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
     // force the job
     boosters.foreachPartition(_ => ())
     println("=====finished training=====")
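
For reviewers, a minimal sketch of the pattern this patch introduces: the driver starts one RabitTracker and ships tracker.getWorkerEnvs into the task closures, and each task merges in its own DMLC_TASK_ID before calling Rabit.init. Previously Rabit.init received a fresh HashMap containing only DMLC_TASK_ID, so the tracker's connection info never reached the workers. The sketch below is hypothetical glue, not part of the patch; the SparkContext setup, the numWorkers value, and the parallelize stand-in for the training RDD are illustrative assumptions.

    // Hypothetical, self-contained driver sketch (not part of this patch)
    // showing the env-propagation pattern the diff introduces.
    import scala.collection.JavaConverters._

    import org.apache.spark.{SparkContext, TaskContext}

    import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker}

    object RabitEnvFlowSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext() // assumes master/app name come from system properties
        val numWorkers = 4          // illustrative value

        // Driver side: start the tracker once; getWorkerEnvs carries the
        // connection info (tracker host, port, worker count) that every
        // worker needs in order to join the Rabit ring.
        val tracker = new RabitTracker(numWorkers)
        require(tracker.start(), "FAULT: Failed to start tracker")
        val rabitEnv = tracker.getWorkerEnvs.asScala

        // Executor side: rabitEnv is captured by the task closure; each task
        // appends its own DMLC_TASK_ID so workers register under distinct ids,
        // which the old per-task HashMap could not combine with the tracker envs.
        sc.parallelize(0 until numWorkers, numWorkers).foreachPartition { _ =>
          rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
          Rabit.init(rabitEnv.asJava)
          // ... per-partition training runs here in the real code ...
          Rabit.shutdown()
        }
      }
    }

One design note: threading rabitEnv through buildDistributedBoosters as a mutable.Map (rather than rebuilding a Java HashMap inside each task) lets the driver compute the tracker envs exactly once and reuse Scala/Java collection interop (asScala/asJava) at the boundaries, which is why the JavaConverters import appears in the first hunk.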