diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 43f602df6..db6bc8a98 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -161,10 +161,5 @@ 2.2.6 test - - com.typesafe - config - 1.2.1 - diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 8b0d0a71e..ea7ba8563 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -16,10 +16,10 @@ package ml.dmlc.xgboost4j.scala.spark -import scala.collection.immutable.HashMap +import scala.collection.mutable -import com.typesafe.config.Config -import org.apache.spark.{TaskContext, SparkContext} +import org.apache.commons.logging.LogFactory +import org.apache.spark.TaskContext import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -28,6 +28,9 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} object XGBoost extends Serializable { + var boosters: RDD[Booster] = null + private val logger = LogFactory.getLog("XGBoostSpark") + implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = { new XGBoostModel(booster) } @@ -37,42 +40,33 @@ object XGBoost extends Serializable { xgBoostConfMap: Map[String, AnyRef], numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = { import DataUtils._ - val sc = trainingData.sparkContext - val tracker = new RabitTracker(numWorkers) - if (tracker.start()) { - trainingData.repartition(numWorkers).mapPartitions { - trainingSamples => - Rabit.init(new java.util.HashMap[String, String]() { - put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) - }) - val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null)) - val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round, - watches = new HashMap[String, DMatrix], obj, eval) - Rabit.shutdown() - Iterator(booster) - }.cache() - } else { - null - } + trainingData.repartition(numWorkers).mapPartitions { + trainingSamples => + Rabit.init(new java.util.HashMap[String, String]() { + put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) + }) + val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null)) + val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round, + watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval) + Rabit.shutdown() + Iterator(booster) + }.cache() } - def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null, - eval: EvalTrait = null): Option[XGBoostModel] = { - import DataUtils._ - val numWorkers = config.getInt("numWorkers") - val round = config.getInt("round") + def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int, + obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = { + val numWorkers = trainingData.partitions.length val sc = trainingData.sparkContext val tracker = new RabitTracker(numWorkers) - if (tracker.start()) { - // TODO: build configuration map from config - val xgBoostConfigMap = new HashMap[String, AnyRef]() - val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round, - obj, eval) - // force the job - sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters) - tracker.waitFor() - // TODO: how to choose best model - Some(boosters.first()) + require(tracker.start(), "FAULT: Failed to start tracker") + boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval) + // force the job + sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters) + val booster = boosters.first() + val returnVal = tracker.waitFor() + logger.info(s"Rabit returns with exit code $returnVal") + if (returnVal == 0) { + Some(booster) } else { None } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java index 762cff7bf..a5768d6cd 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java @@ -134,15 +134,18 @@ public class RabitTracker { } } - public void waitFor() { + public int waitFor() { try { trackerProcess.get().waitFor(); - logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue()); + int returnVal = trackerProcess.get().exitValue(); + logger.info("Tracker Process ends with exit code " + returnVal); stop(); + return returnVal; } catch (InterruptedException e) { // we should not get here as RabitTracker is accessed in the main thread e.printStackTrace(); logger.error("the RabitTracker thread is terminated unexpectedly"); + return 1; } } }