diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 43f602df6..db6bc8a98 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -161,10 +161,5 @@
2.2.6
test
-
- com.typesafe
- config
- 1.2.1
-
diff --git a/jvm-packages/xgboost4j-demo/pom.xml b/jvm-packages/xgboost4j-demo/pom.xml
index bef184adb..e076af63d 100644
--- a/jvm-packages/xgboost4j-demo/pom.xml
+++ b/jvm-packages/xgboost4j-demo/pom.xml
@@ -25,7 +25,7 @@
ml.dmlc
- xgboost4j
+ xgboost4j-spark
0.1
diff --git a/jvm-packages/xgboost4j-demo/src/main/scala/ml/dmlc/xgboost4j/scala/spark/demo/DistTrainWithSpark.scala b/jvm-packages/xgboost4j-demo/src/main/scala/ml/dmlc/xgboost4j/scala/spark/demo/DistTrainWithSpark.scala
new file mode 100644
index 000000000..8fd794423
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/scala/ml/dmlc/xgboost4j/scala/spark/demo/DistTrainWithSpark.scala
@@ -0,0 +1,74 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.demo
+
+import java.io.File
+
+import scala.collection.mutable.ListBuffer
+import scala.io.Source
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.DenseVector
+import org.apache.spark.mllib.regression.LabeledPoint
+
+import ml.dmlc.xgboost4j.scala.DMatrix
+import ml.dmlc.xgboost4j.scala.spark.XGBoost
+
+
+object DistTrainWithSpark {
+
+ private def readFile(filePath: String): List[LabeledPoint] = {
+ val file = Source.fromFile(new File(filePath))
+ val sampleList = new ListBuffer[LabeledPoint]
+ for (sample <- file.getLines()) {
+ sampleList += fromSVMStringToLabeledPoint(sample)
+ }
+ sampleList.toList
+ }
+
+ private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
+ val labelAndFeatures = line.split(" ")
+ val label = labelAndFeatures(0).toInt
+ val features = labelAndFeatures.tail
+ val denseFeature = new Array[Double](129)
+ for (feature <- features) {
+ val idAndValue = feature.split(":")
+ denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
+ }
+ LabeledPoint(label, new DenseVector(denseFeature))
+ }
+
+ def main(args: Array[String]): Unit = {
+ import ml.dmlc.xgboost4j.scala.spark.DataUtils._
+ if (args.length != 4) {
+ println(
+ "usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
+ sys.exit(1)
+ }
+ val sc = new SparkContext()
+ val inputTrainPath = args(2)
+ val inputTestPath = args(3)
+ val trainingLabeledPoints = readFile(inputTrainPath)
+ val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt)
+ val testLabeledPoints = readFile(inputTestPath).iterator
+ val testMatrix = new DMatrix(testLabeledPoints, null)
+ val booster = XGBoost.train(trainRDD,
+ List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+ "objective" -> "binary:logistic").toMap, args(1).toInt, null, null)
+ booster.map(boosterInstance => boosterInstance.predict(testMatrix))
+ }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
index 12fb545c9..d61cb9fc1 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
@@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
import ml.dmlc.xgboost4j.LabeledPoint
-private[spark] object DataUtils extends Serializable {
+object DataUtils extends Serializable {
+
+ implicit def fromSparkToXGBoostLabeledPointsAsJava(
+ sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
+ fromSparkToXGBoostLabeledPoints(sps).asJava
+ }
implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
- java.util.Iterator[LabeledPoint] = {
- (for (p <- sps) yield {
+ Iterator[LabeledPoint] = {
+ for (p <- sps) yield {
p.features match {
case denseFeature: DenseVector =>
LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
@@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable {
LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
sparseFeature.values.map(_.toFloat))
}
- }).asJava
- }
-
- private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
- (sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
- }
-
- private def fetchUpdateFromVector(feature: Vector) = feature match {
- case denseFeature: DenseVector =>
- fetchUpdateFromSparseVector(denseFeature.toSparse)
- case sparseFeature: SparseVector =>
- fetchUpdateFromSparseVector(sparseFeature)
+ }
}
}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8b0d0a71e..a7c802dc1 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -16,10 +16,11 @@
package ml.dmlc.xgboost4j.scala.spark
-import scala.collection.immutable.HashMap
+import scala.collection.mutable
+import scala.collection.JavaConverters._
-import com.typesafe.config.Config
-import org.apache.spark.{TaskContext, SparkContext}
+import org.apache.commons.logging.LogFactory
+import org.apache.spark.TaskContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -28,6 +29,9 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
object XGBoost extends Serializable {
+ var boosters: RDD[Booster] = null
+ private val logger = LogFactory.getLog("XGBoostSpark")
+
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
new XGBoostModel(booster)
}
@@ -35,45 +39,51 @@ object XGBoost extends Serializable {
private[spark] def buildDistributedBoosters(
trainingData: RDD[LabeledPoint],
xgBoostConfMap: Map[String, AnyRef],
+ rabitEnv: mutable.Map[String, String],
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
import DataUtils._
- val sc = trainingData.sparkContext
- val tracker = new RabitTracker(numWorkers)
- if (tracker.start()) {
- trainingData.repartition(numWorkers).mapPartitions {
- trainingSamples =>
- Rabit.init(new java.util.HashMap[String, String]() {
- put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
- })
- val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
- val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
- watches = new HashMap[String, DMatrix], obj, eval)
- Rabit.shutdown()
- Iterator(booster)
- }.cache()
- } else {
- null
- }
+ trainingData.repartition(numWorkers).mapPartitions {
+ trainingSamples =>
+ rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
+ Rabit.init(rabitEnv.asJava)
+ val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
+ val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
+ watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval)
+ Rabit.shutdown()
+ Iterator(booster)
+ }.cache()
}
- def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
- eval: EvalTrait = null): Option[XGBoostModel] = {
- import DataUtils._
- val numWorkers = config.getInt("numWorkers")
- val round = config.getInt("round")
+ def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int,
+ obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = {
+ val numWorkers = trainingData.partitions.length
val sc = trainingData.sparkContext
val tracker = new RabitTracker(numWorkers)
- if (tracker.start()) {
- // TODO: build configuration map from config
- val xgBoostConfigMap = new HashMap[String, AnyRef]()
- val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
- obj, eval)
- // force the job
- sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
- tracker.waitFor()
- // TODO: how to choose best model
- Some(boosters.first())
+ require(tracker.start(), "FAULT: Failed to start tracker")
+ boosters = buildDistributedBoosters(trainingData, configMap,
+ tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
+ @volatile var booster: Booster = null
+ val sparkJobThread = new Thread() {
+ override def run() {
+ // force the job
+ boosters.foreachPartition(_ => ())
+ }
+ }
+ sparkJobThread.start()
+ val returnVal = tracker.waitFor()
+ logger.info(s"Rabit returns with exit code $returnVal")
+ if (returnVal == 0) {
+ booster = boosters.first()
+ Some(booster)
} else {
+ try {
+ if (sparkJobThread.isAlive) {
+ sparkJobThread.interrupt()
+ }
+ } catch {
+ case ie: InterruptedException =>
+ logger.info("spark job thread is interrupted")
+ }
None
}
}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
index 23c9924d1..ca1fe9ada 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
@@ -130,6 +130,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
trainingRDD,
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap,
+ new scala.collection.mutable.HashMap[String, String],
numWorker, 2, null, null)
val boosterCount = boosterRDD.count()
assert(boosterCount === numWorker)
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java
index fc14e361e..5f4351eb1 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java
@@ -1,10 +1,12 @@
package ml.dmlc.xgboost4j;
+import java.io.Serializable;
+
/**
* Labeled data point for training examples.
* Represent a sparse training instance.
*/
-public class LabeledPoint {
+public class LabeledPoint implements Serializable {
/** Label of the point */
public float label;
/** Weight of this data point */
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
index 08abc1afc..5778149f2 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
@@ -24,7 +24,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
- * Booster for xgboost, this is a model API that support interactive build of a XGBOost Model
+ * Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
*/
public class Booster implements Serializable {
private static final Log logger = LogFactory.getLog(Booster.class);
@@ -353,10 +353,26 @@ public class Booster implements Serializable {
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
- * If java natively support HDFS file API, use toByteArray and write the ByteArray,
+ * If java natively support HDFS file API, use toByteArray and write the ByteArray
+ *
+ * @param withStats Controls whether the split statistics are output.
+ * @return dumped model information
+ * @throws XGBoostError native error
+ */
+ private String[] getDumpInfo(boolean withStats) throws XGBoostError {
+ int statsFlag = 0;
+ if (withStats) {
+ statsFlag = 1;
+ }
+ String[][] modelInfos = new String[1][];
+ JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
+ return modelInfos[0];
+ }
+
+ /**
*
* @return the saved byte array.
- * @throws XGBoostError
+ * @throws XGBoostError native error
*/
public byte[] toByteArray() throws XGBoostError {
byte[][] bytes = new byte[1][];
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
index 2a52d0b9b..d2ff3b612 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
@@ -1,5 +1,6 @@
package ml.dmlc.xgboost4j.java;
+import java.io.Serializable;
import java.util.Iterator;
import ml.dmlc.xgboost4j.LabeledPoint;
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java
index 0bc069048..3429dc3dd 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java
@@ -1,6 +1,7 @@
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
+import java.io.Serializable;
import java.util.Map;
import org.apache.commons.logging.Log;
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
index 762cff7bf..5b04ac432 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
@@ -47,8 +47,15 @@ public class RabitTracker {
while ((line = reader.readLine()) != null) {
trackerProcessLogger.info(line);
}
+ trackerProcess.get().waitFor();
+ trackerProcessLogger.info("Tracker Process ends with exit code " +
+ trackerProcess.get().exitValue());
} catch (IOException ex) {
trackerProcessLogger.error(ex.toString());
+ } catch (InterruptedException ie) {
+ // we should not get here as RabitTracker is accessed in the main thread
+ ie.printStackTrace();
+ logger.error("the RabitTracker thread is terminated unexpectedly");
}
}
}
@@ -134,15 +141,18 @@ public class RabitTracker {
}
}
- public void waitFor() {
+ public int waitFor() {
try {
trackerProcess.get().waitFor();
- logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
+ int returnVal = trackerProcess.get().exitValue();
+ logger.info("Tracker Process ends with exit code " + returnVal);
stop();
+ return returnVal;
} catch (InterruptedException e) {
// we should not get here as RabitTracker is accessed in the main thread
e.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
+ return 1;
}
}
}