diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 43f602df6..db6bc8a98 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -161,10 +161,5 @@
2.2.6
test
-
- com.typesafe
- config
- 1.2.1
-
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8b0d0a71e..ea7ba8563 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -16,10 +16,10 @@
package ml.dmlc.xgboost4j.scala.spark
-import scala.collection.immutable.HashMap
+import scala.collection.mutable
-import com.typesafe.config.Config
-import org.apache.spark.{TaskContext, SparkContext}
+import org.apache.commons.logging.LogFactory
+import org.apache.spark.TaskContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -28,6 +28,9 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
object XGBoost extends Serializable {
+ var boosters: RDD[Booster] = null
+ private val logger = LogFactory.getLog("XGBoostSpark")
+
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
new XGBoostModel(booster)
}
@@ -37,42 +40,33 @@ object XGBoost extends Serializable {
xgBoostConfMap: Map[String, AnyRef],
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
import DataUtils._
- val sc = trainingData.sparkContext
- val tracker = new RabitTracker(numWorkers)
- if (tracker.start()) {
- trainingData.repartition(numWorkers).mapPartitions {
- trainingSamples =>
- Rabit.init(new java.util.HashMap[String, String]() {
- put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
- })
- val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
- val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
- watches = new HashMap[String, DMatrix], obj, eval)
- Rabit.shutdown()
- Iterator(booster)
- }.cache()
- } else {
- null
- }
+ trainingData.repartition(numWorkers).mapPartitions {
+ trainingSamples =>
+ Rabit.init(new java.util.HashMap[String, String]() {
+ put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
+ })
+ val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
+ val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
+ watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval)
+ Rabit.shutdown()
+ Iterator(booster)
+ }.cache()
}
- def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
- eval: EvalTrait = null): Option[XGBoostModel] = {
- import DataUtils._
- val numWorkers = config.getInt("numWorkers")
- val round = config.getInt("round")
+ def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int,
+ obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = {
+ val numWorkers = trainingData.partitions.length
val sc = trainingData.sparkContext
val tracker = new RabitTracker(numWorkers)
- if (tracker.start()) {
- // TODO: build configuration map from config
- val xgBoostConfigMap = new HashMap[String, AnyRef]()
- val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
- obj, eval)
- // force the job
- sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
- tracker.waitFor()
- // TODO: how to choose best model
- Some(boosters.first())
+ require(tracker.start(), "FAULT: Failed to start tracker")
+ boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval)
+ // force the job
+ sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
+ val booster = boosters.first()
+ val returnVal = tracker.waitFor()
+ logger.info(s"Rabit returns with exit code $returnVal")
+ if (returnVal == 0) {
+ Some(booster)
} else {
None
}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
index 762cff7bf..a5768d6cd 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
@@ -134,15 +134,18 @@ public class RabitTracker {
}
}
- public void waitFor() {
+ public int waitFor() {
try {
trackerProcess.get().waitFor();
- logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
+ int returnVal = trackerProcess.get().exitValue();
+ logger.info("Tracker Process ends with exit code " + returnVal);
stop();
+ return returnVal;
} catch (InterruptedException e) {
// we should not get here as RabitTracker is accessed in the main thread
e.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
+ return 1;
}
}
}