diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 43f602df6..db6bc8a98 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -161,10 +161,5 @@ 2.2.6 test - - com.typesafe - config - 1.2.1 - diff --git a/jvm-packages/xgboost4j-demo/pom.xml b/jvm-packages/xgboost4j-demo/pom.xml index bef184adb..e076af63d 100644 --- a/jvm-packages/xgboost4j-demo/pom.xml +++ b/jvm-packages/xgboost4j-demo/pom.xml @@ -25,7 +25,7 @@ ml.dmlc - xgboost4j + xgboost4j-spark 0.1 diff --git a/jvm-packages/xgboost4j-demo/src/main/scala/ml/dmlc/xgboost4j/scala/spark/demo/DistTrainWithSpark.scala b/jvm-packages/xgboost4j-demo/src/main/scala/ml/dmlc/xgboost4j/scala/spark/demo/DistTrainWithSpark.scala new file mode 100644 index 000000000..8fd794423 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/scala/ml/dmlc/xgboost4j/scala/spark/demo/DistTrainWithSpark.scala @@ -0,0 +1,74 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark.demo + +import java.io.File + +import scala.collection.mutable.ListBuffer +import scala.io.Source + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.regression.LabeledPoint + +import ml.dmlc.xgboost4j.scala.DMatrix +import ml.dmlc.xgboost4j.scala.spark.XGBoost + + +object DistTrainWithSpark { + + private def readFile(filePath: String): List[LabeledPoint] = { + val file = Source.fromFile(new File(filePath)) + val sampleList = new ListBuffer[LabeledPoint] + for (sample <- file.getLines()) { + sampleList += fromSVMStringToLabeledPoint(sample) + } + sampleList.toList + } + + private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = { + val labelAndFeatures = line.split(" ") + val label = labelAndFeatures(0).toInt + val features = labelAndFeatures.tail + val denseFeature = new Array[Double](129) + for (feature <- features) { + val idAndValue = feature.split(":") + denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble + } + LabeledPoint(label, new DenseVector(denseFeature)) + } + + def main(args: Array[String]): Unit = { + import ml.dmlc.xgboost4j.scala.spark.DataUtils._ + if (args.length != 4) { + println( + "usage: program number_of_trainingset_partitions num_of_rounds training_path test_path") + sys.exit(1) + } + val sc = new SparkContext() + val inputTrainPath = args(2) + val inputTestPath = args(3) + val trainingLabeledPoints = readFile(inputTrainPath) + val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt) + val testLabeledPoints = readFile(inputTestPath).iterator + val testMatrix = new DMatrix(testLabeledPoints, null) + val booster = XGBoost.train(trainRDD, + List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", + "objective" -> "binary:logistic").toMap, args(1).toInt, null, null) + booster.map(boosterInstance => boosterInstance.predict(testMatrix)) + } +} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala index 12fb545c9..d61cb9fc1 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala @@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint} import ml.dmlc.xgboost4j.LabeledPoint -private[spark] object DataUtils extends Serializable { +object DataUtils extends Serializable { + + implicit def fromSparkToXGBoostLabeledPointsAsJava( + sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = { + fromSparkToXGBoostLabeledPoints(sps).asJava + } implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]): - java.util.Iterator[LabeledPoint] = { - (for (p <- sps) yield { + Iterator[LabeledPoint] = { + for (p <- sps) yield { p.features match { case denseFeature: DenseVector => LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat)) @@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable { LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices, sparseFeature.values.map(_.toFloat)) } - }).asJava - } - - private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = { - (sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList) - } - - private def fetchUpdateFromVector(feature: Vector) = feature match { - case denseFeature: DenseVector => - fetchUpdateFromSparseVector(denseFeature.toSparse) - case sparseFeature: SparseVector => - fetchUpdateFromSparseVector(sparseFeature) + } } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 8b0d0a71e..a7c802dc1 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -16,10 +16,11 @@ package ml.dmlc.xgboost4j.scala.spark -import scala.collection.immutable.HashMap +import scala.collection.mutable +import scala.collection.JavaConverters._ -import com.typesafe.config.Config -import org.apache.spark.{TaskContext, SparkContext} +import org.apache.commons.logging.LogFactory +import org.apache.spark.TaskContext import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -28,6 +29,9 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} object XGBoost extends Serializable { + var boosters: RDD[Booster] = null + private val logger = LogFactory.getLog("XGBoostSpark") + implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = { new XGBoostModel(booster) } @@ -35,45 +39,51 @@ object XGBoost extends Serializable { private[spark] def buildDistributedBoosters( trainingData: RDD[LabeledPoint], xgBoostConfMap: Map[String, AnyRef], + rabitEnv: mutable.Map[String, String], numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = { import DataUtils._ - val sc = trainingData.sparkContext - val tracker = new RabitTracker(numWorkers) - if (tracker.start()) { - trainingData.repartition(numWorkers).mapPartitions { - trainingSamples => - Rabit.init(new java.util.HashMap[String, String]() { - put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) - }) - val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null)) - val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round, - watches = new HashMap[String, DMatrix], obj, eval) - Rabit.shutdown() - Iterator(booster) - }.cache() - } else { - null - } + trainingData.repartition(numWorkers).mapPartitions { + trainingSamples => + rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) + Rabit.init(rabitEnv.asJava) + val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null)) + val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round, + watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval) + Rabit.shutdown() + Iterator(booster) + }.cache() } - def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null, - eval: EvalTrait = null): Option[XGBoostModel] = { - import DataUtils._ - val numWorkers = config.getInt("numWorkers") - val round = config.getInt("round") + def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int, + obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = { + val numWorkers = trainingData.partitions.length val sc = trainingData.sparkContext val tracker = new RabitTracker(numWorkers) - if (tracker.start()) { - // TODO: build configuration map from config - val xgBoostConfigMap = new HashMap[String, AnyRef]() - val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round, - obj, eval) - // force the job - sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters) - tracker.waitFor() - // TODO: how to choose best model - Some(boosters.first()) + require(tracker.start(), "FAULT: Failed to start tracker") + boosters = buildDistributedBoosters(trainingData, configMap, + tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval) + @volatile var booster: Booster = null + val sparkJobThread = new Thread() { + override def run() { + // force the job + boosters.foreachPartition(_ => ()) + } + } + sparkJobThread.start() + val returnVal = tracker.waitFor() + logger.info(s"Rabit returns with exit code $returnVal") + if (returnVal == 0) { + booster = boosters.first() + Some(booster) } else { + try { + if (sparkJobThread.isAlive) { + sparkJobThread.interrupt() + } + } catch { + case ie: InterruptedException => + logger.info("spark job thread is interrupted") + } None } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala index 23c9924d1..ca1fe9ada 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala @@ -130,6 +130,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll { trainingRDD, List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", "objective" -> "binary:logistic").toMap, + new scala.collection.mutable.HashMap[String, String], numWorker, 2, null, null) val boosterCount = boosterRDD.count() assert(boosterCount === numWorker) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java index fc14e361e..5f4351eb1 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java @@ -1,10 +1,12 @@ package ml.dmlc.xgboost4j; +import java.io.Serializable; + /** * Labeled data point for training examples. * Represent a sparse training instance. */ -public class LabeledPoint { +public class LabeledPoint implements Serializable { /** Label of the point */ public float label; /** Weight of this data point */ diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 08abc1afc..5778149f2 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -24,7 +24,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; /** - * Booster for xgboost, this is a model API that support interactive build of a XGBOost Model + * Booster for xgboost, this is a model API that support interactive build of a XGBoost Model */ public class Booster implements Serializable { private static final Log logger = LogFactory.getLog(Booster.class); @@ -353,10 +353,26 @@ public class Booster implements Serializable { * Save the model as byte array representation. * Write these bytes to a file will give compatible format with other xgboost bindings. * - * If java natively support HDFS file API, use toByteArray and write the ByteArray, + * If java natively support HDFS file API, use toByteArray and write the ByteArray + * + * @param withStats Controls whether the split statistics are output. + * @return dumped model information + * @throws XGBoostError native error + */ + private String[] getDumpInfo(boolean withStats) throws XGBoostError { + int statsFlag = 0; + if (withStats) { + statsFlag = 1; + } + String[][] modelInfos = new String[1][]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos)); + return modelInfos[0]; + } + + /** * * @return the saved byte array. - * @throws XGBoostError + * @throws XGBoostError native error */ public byte[] toByteArray() throws XGBoostError { byte[][] bytes = new byte[1][]; diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java index 2a52d0b9b..d2ff3b612 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java @@ -1,5 +1,6 @@ package ml.dmlc.xgboost4j.java; +import java.io.Serializable; import java.util.Iterator; import ml.dmlc.xgboost4j.LabeledPoint; diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java index 0bc069048..3429dc3dd 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java @@ -1,6 +1,7 @@ package ml.dmlc.xgboost4j.java; import java.io.IOException; +import java.io.Serializable; import java.util.Map; import org.apache.commons.logging.Log; diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java index 762cff7bf..5b04ac432 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java @@ -47,8 +47,15 @@ public class RabitTracker { while ((line = reader.readLine()) != null) { trackerProcessLogger.info(line); } + trackerProcess.get().waitFor(); + trackerProcessLogger.info("Tracker Process ends with exit code " + + trackerProcess.get().exitValue()); } catch (IOException ex) { trackerProcessLogger.error(ex.toString()); + } catch (InterruptedException ie) { + // we should not get here as RabitTracker is accessed in the main thread + ie.printStackTrace(); + logger.error("the RabitTracker thread is terminated unexpectedly"); } } } @@ -134,15 +141,18 @@ public class RabitTracker { } } - public void waitFor() { + public int waitFor() { try { trackerProcess.get().waitFor(); - logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue()); + int returnVal = trackerProcess.get().exitValue(); + logger.info("Tracker Process ends with exit code " + returnVal); stop(); + return returnVal; } catch (InterruptedException e) { // we should not get here as RabitTracker is accessed in the main thread e.printStackTrace(); logger.error("the RabitTracker thread is terminated unexpectedly"); + return 1; } } }