Merge pull request #927 from CodingCat/spark_example

Spark example
Tianqi Chen 2016-03-06 15:46:54 -08:00
commit 79f9fceb6b
11 changed files with 166 additions and 62 deletions

View File

@@ -161,10 +161,5 @@
       <version>2.2.6</version>
       <scope>test</scope>
     </dependency>
-    <dependency>
-      <groupId>com.typesafe</groupId>
-      <artifactId>config</artifactId>
-      <version>1.2.1</version>
-    </dependency>
   </dependencies>
 </project>

View File

@@ -25,7 +25,7 @@
   <dependencies>
     <dependency>
       <groupId>ml.dmlc</groupId>
-      <artifactId>xgboost4j</artifactId>
+      <artifactId>xgboost4j-spark</artifactId>
      <version>0.1</version>
    </dependency>
    <dependency>

View File

@@ -0,0 +1,74 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark.demo

import java.io.File

import scala.collection.mutable.ListBuffer
import scala.io.Source

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.regression.LabeledPoint

import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.XGBoost

object DistTrainWithSpark {

  private def readFile(filePath: String): List[LabeledPoint] = {
    val file = Source.fromFile(new File(filePath))
    val sampleList = new ListBuffer[LabeledPoint]
    for (sample <- file.getLines()) {
      sampleList += fromSVMStringToLabeledPoint(sample)
    }
    sampleList.toList
  }

  private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
    val labelAndFeatures = line.split(" ")
    val label = labelAndFeatures(0).toInt
    val features = labelAndFeatures.tail
    val denseFeature = new Array[Double](129)
    for (feature <- features) {
      val idAndValue = feature.split(":")
      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
    }
    LabeledPoint(label, new DenseVector(denseFeature))
  }

  def main(args: Array[String]): Unit = {
    import ml.dmlc.xgboost4j.scala.spark.DataUtils._
    if (args.length != 4) {
      println(
        "usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
      sys.exit(1)
    }
    val sc = new SparkContext()
    val inputTrainPath = args(2)
    val inputTestPath = args(3)
    val trainingLabeledPoints = readFile(inputTrainPath)
    val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt)
    val testLabeledPoints = readFile(inputTestPath).iterator
    val testMatrix = new DMatrix(testLabeledPoints, null)
    val booster = XGBoost.train(trainRDD,
      List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
        "objective" -> "binary:logistic").toMap, args(1).toInt, null, null)
    booster.map(boosterInstance => boosterInstance.predict(testMatrix))
  }
}
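
A note on the demo's last line: train returns an Option[XGBoostModel], so the map call silently discards the predictions. A minimal sketch of unwrapping the result instead, reusing the trainRDD and testMatrix built above (names are illustrative, not part of the diff):

    // Sketch only: consume the Option[XGBoostModel] returned by XGBoost.train.
    XGBoost.train(trainRDD,
      List("eta" -> "1", "max_depth" -> "2",
        "objective" -> "binary:logistic").toMap, round = 2) match {
      case Some(model) =>
        // predict against the test DMatrix built above
        val predictions = model.predict(testMatrix)
      case None =>
        sys.exit(1) // the Rabit tracker reported a non-zero exit code
    }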

View File

@@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
 import ml.dmlc.xgboost4j.LabeledPoint
 
-private[spark] object DataUtils extends Serializable {
+object DataUtils extends Serializable {
+  implicit def fromSparkToXGBoostLabeledPointsAsJava(
+      sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
+    fromSparkToXGBoostLabeledPoints(sps).asJava
+  }
+
   implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
-      java.util.Iterator[LabeledPoint] = {
-    (for (p <- sps) yield {
+      Iterator[LabeledPoint] = {
+    for (p <- sps) yield {
       p.features match {
         case denseFeature: DenseVector =>
           LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
@@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable {
           LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
             sparseFeature.values.map(_.toFloat))
       }
-    }).asJava
+    }
   }
-
-  private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
-    (sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
-  }
-
-  private def fetchUpdateFromVector(feature: Vector) = feature match {
-    case denseFeature: DenseVector =>
-      fetchUpdateFromSparseVector(denseFeature.toSparse)
-    case sparseFeature: SparseVector =>
-      fetchUpdateFromSparseVector(sparseFeature)
-  }
 }
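
The asJava variant exists so that call sites needing xgboost4j's java.util.Iterator, such as the DMatrix constructor used in the demo, can still rely on an implicit conversion. A minimal sketch with illustrative toy data (only the imports and the conversion come from this diff):

    import ml.dmlc.xgboost4j.scala.DMatrix
    import ml.dmlc.xgboost4j.scala.spark.DataUtils._
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Two dense toy points; fromSparkToXGBoostLabeledPointsAsJava implicitly
    // turns the Scala iterator into the java.util.Iterator DMatrix expects.
    val points = Iterator(
      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)))
    val matrix = new DMatrix(points, null)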

View File

@@ -16,10 +16,11 @@
 package ml.dmlc.xgboost4j.scala.spark
 
-import scala.collection.immutable.HashMap
+import scala.collection.mutable
+import scala.collection.JavaConverters._
 
-import com.typesafe.config.Config
-import org.apache.spark.{TaskContext, SparkContext}
+import org.apache.commons.logging.LogFactory
+import org.apache.spark.TaskContext
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
@@ -28,6 +29,9 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 object XGBoost extends Serializable {
+  var boosters: RDD[Booster] = null
+  private val logger = LogFactory.getLog("XGBoostSpark")
+
   implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
     new XGBoostModel(booster)
   }
@@ -35,45 +39,51 @@ object XGBoost extends Serializable {
   private[spark] def buildDistributedBoosters(
       trainingData: RDD[LabeledPoint],
       xgBoostConfMap: Map[String, AnyRef],
+      rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
     import DataUtils._
-    val sc = trainingData.sparkContext
-    val tracker = new RabitTracker(numWorkers)
-    if (tracker.start()) {
-      trainingData.repartition(numWorkers).mapPartitions {
-        trainingSamples =>
-          Rabit.init(new java.util.HashMap[String, String]() {
-            put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
-          })
-          val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
-          val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
-            watches = new HashMap[String, DMatrix], obj, eval)
-          Rabit.shutdown()
-          Iterator(booster)
-      }.cache()
-    } else {
-      null
-    }
+    trainingData.repartition(numWorkers).mapPartitions {
+      trainingSamples =>
+        rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
+        Rabit.init(rabitEnv.asJava)
+        val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
+        val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
+          watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval)
+        Rabit.shutdown()
+        Iterator(booster)
+    }.cache()
   }
 
-  def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
-      eval: EvalTrait = null): Option[XGBoostModel] = {
-    import DataUtils._
-    val numWorkers = config.getInt("numWorkers")
-    val round = config.getInt("round")
+  def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int,
+      obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = {
+    val numWorkers = trainingData.partitions.length
     val sc = trainingData.sparkContext
     val tracker = new RabitTracker(numWorkers)
-    if (tracker.start()) {
-      // TODO: build configuration map from config
-      val xgBoostConfigMap = new HashMap[String, AnyRef]()
-      val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
-        obj, eval)
-      // force the job
-      sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
-      tracker.waitFor()
-      // TODO: how to choose best model
-      Some(boosters.first())
+    require(tracker.start(), "FAULT: Failed to start tracker")
+    boosters = buildDistributedBoosters(trainingData, configMap,
+      tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
+    @volatile var booster: Booster = null
+    val sparkJobThread = new Thread() {
+      override def run() {
+        // force the job
+        boosters.foreachPartition(_ => ())
+      }
+    }
+    sparkJobThread.start()
+    val returnVal = tracker.waitFor()
+    logger.info(s"Rabit returns with exit code $returnVal")
+    if (returnVal == 0) {
+      booster = boosters.first()
+      Some(booster)
     } else {
+      try {
+        if (sparkJobThread.isAlive) {
+          sparkJobThread.interrupt()
+        }
+      } catch {
+        case ie: InterruptedException =>
+          logger.info("spark job thread is interrupted")
+      }
       None
     }
   }
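
Taken together, the reworked public entry point drops the typesafe Config dependency: parameters and the round count are passed in directly, and the worker count is inferred from the RDD's partitioning. A minimal sketch of a call site, assuming trainingRDD is an RDD[LabeledPoint] (the value names are illustrative):

    // Sketch: the train signature introduced by this PR.
    val params: Map[String, AnyRef] = Map(
      "eta" -> "1", "max_depth" -> "2", "objective" -> "binary:logistic")
    val model: Option[XGBoostModel] = XGBoost.train(trainingRDD, params, round = 5)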

View File

@@ -130,6 +130,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
       trainingRDD,
       List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
         "objective" -> "binary:logistic").toMap,
+      new scala.collection.mutable.HashMap[String, String],
       numWorker, 2, null, null)
     val boosterCount = boosterRDD.count()
     assert(boosterCount === numWorker)

View File

@@ -1,10 +1,12 @@
 package ml.dmlc.xgboost4j;
 
+import java.io.Serializable;
+
 /**
  * Labeled data point for training examples.
  * Represent a sparse training instance.
  */
-public class LabeledPoint {
+public class LabeledPoint implements Serializable {
   /** Label of the point */
   public float label;
   /** Weight of this data point */

View File

@@ -24,7 +24,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
 /**
- * Booster for xgboost, this is a model API that support interactive build of a XGBOost Model
+ * Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
  */
 public class Booster implements Serializable {
   private static final Log logger = LogFactory.getLog(Booster.class);
@@ -353,10 +353,26 @@ public class Booster implements Serializable {
    * Save the model as byte array representation.
    * Write these bytes to a file will give compatible format with other xgboost bindings.
    *
-   * If java natively support HDFS file API, use toByteArray and write the ByteArray,
+   * If java natively support HDFS file API, use toByteArray and write the ByteArray
+   *
+   * @param withStats Controls whether the split statistics are output.
+   * @return dumped model information
+   * @throws XGBoostError native error
+   */
+  private String[] getDumpInfo(boolean withStats) throws XGBoostError {
+    int statsFlag = 0;
+    if (withStats) {
+      statsFlag = 1;
+    }
+    String[][] modelInfos = new String[1][];
+    JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
+    return modelInfos[0];
+  }
+
+  /**
    *
    * @return the saved byte array.
-   * @throws XGBoostError
+   * @throws XGBoostError native error
    */
   public byte[] toByteArray() throws XGBoostError {
     byte[][] bytes = new byte[1][];
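
Since toByteArray produces the same serialized form as the other xgboost bindings, persisting a model reduces to writing the bytes out. A minimal Scala sketch, assuming booster is a trained Booster instance and the output path is illustrative:

    import java.nio.file.{Files, Paths}

    // The byte array is binding-compatible, so the file can later be loaded
    // from Python, R, or the CLI as well.
    val bytes: Array[Byte] = booster.toByteArray
    Files.write(Paths.get("/tmp/xgb.model"), bytes)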

View File

@@ -1,5 +1,6 @@
 package ml.dmlc.xgboost4j.java;
 
+import java.io.Serializable;
 import java.util.Iterator;
 
 import ml.dmlc.xgboost4j.LabeledPoint;

View File

@@ -1,6 +1,7 @@
 package ml.dmlc.xgboost4j.java;
 
 import java.io.IOException;
+import java.io.Serializable;
 import java.util.Map;
 
 import org.apache.commons.logging.Log;

View File

@@ -47,8 +47,15 @@ public class RabitTracker {
       while ((line = reader.readLine()) != null) {
         trackerProcessLogger.info(line);
       }
+      trackerProcess.get().waitFor();
+      trackerProcessLogger.info("Tracker Process ends with exit code " +
+        trackerProcess.get().exitValue());
     } catch (IOException ex) {
       trackerProcessLogger.error(ex.toString());
+    } catch (InterruptedException ie) {
+      // we should not get here as RabitTracker is accessed in the main thread
+      ie.printStackTrace();
+      logger.error("the RabitTracker thread is terminated unexpectedly");
     }
   }
 }
@@ -134,15 +141,18 @@ public class RabitTracker {
     }
   }
 
-  public void waitFor() {
+  public int waitFor() {
     try {
       trackerProcess.get().waitFor();
-      logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
+      int returnVal = trackerProcess.get().exitValue();
+      logger.info("Tracker Process ends with exit code " + returnVal);
       stop();
+      return returnVal;
     } catch (InterruptedException e) {
       // we should not get here as RabitTracker is accessed in the main thread
       e.printStackTrace();
       logger.error("the RabitTracker thread is terminated unexpectedly");
+      return 1;
     }
   }
 }
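
With waitFor() now reporting the tracker's exit code, the driver-side contract becomes: start the tracker, hand its worker environment to the job, then block on the code. A compact sketch of that lifecycle, mirroring what XGBoost.train above does (numWorkers and the job launch are illustrative):

    import scala.collection.JavaConverters._

    val tracker = new RabitTracker(numWorkers)
    require(tracker.start(), "failed to start the Rabit tracker")
    val workerEnvs = tracker.getWorkerEnvs.asScala // handed to each partition
    // ... launch the Spark job whose partitions call Rabit.init(workerEnvs.asJava) ...
    if (tracker.waitFor() != 0) {
      // at least one worker failed; discard any cached boosters
    }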