Merge pull request #927 from CodingCat/spark_example

Spark example
Tianqi Chen 2016-03-06 15:46:54 -08:00
commit 79f9fceb6b
11 changed files with 166 additions and 62 deletions

View File

@@ -161,10 +161,5 @@
<version>2.2.6</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.typesafe</groupId>
<artifactId>config</artifactId>
<version>1.2.1</version>
</dependency>
</dependencies>
</project>

View File

@@ -25,7 +25,7 @@
<dependencies>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<artifactId>xgboost4j-spark</artifactId>
<version>0.1</version>
</dependency>
<dependency>

View File

@@ -0,0 +1,74 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.demo
import java.io.File
import scala.collection.mutable.ListBuffer
import scala.io.Source
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.regression.LabeledPoint
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.XGBoost
object DistTrainWithSpark {
private def readFile(filePath: String): List[LabeledPoint] = {
val file = Source.fromFile(new File(filePath))
val sampleList = new ListBuffer[LabeledPoint]
for (sample <- file.getLines()) {
sampleList += fromSVMStringToLabeledPoint(sample)
}
sampleList.toList
}
private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
val labelAndFeatures = line.split(" ")
val label = labelAndFeatures(0).toInt
val features = labelAndFeatures.tail
val denseFeature = new Array[Double](129)
for (feature <- features) {
val idAndValue = feature.split(":")
denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
}
LabeledPoint(label, new DenseVector(denseFeature))
}
def main(args: Array[String]): Unit = {
import ml.dmlc.xgboost4j.scala.spark.DataUtils._
if (args.length != 4) {
println(
"usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
sys.exit(1)
}
val sc = new SparkContext()
val inputTrainPath = args(2)
val inputTestPath = args(3)
val trainingLabeledPoints = readFile(inputTrainPath)
val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt)
val testLabeledPoints = readFile(inputTestPath).iterator
val testMatrix = new DMatrix(testLabeledPoints, null)
val booster = XGBoost.train(trainRDD,
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap, args(1).toInt, null, null)
booster.map(boosterInstance => boosterInstance.predict(testMatrix))
}
}
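For reference, the same flow can be written with MLlib's LibSVM loader instead of the hand-rolled dense parser above. This is a hedged sketch, not part of the patch: the input path and parameter values are illustrative, and it assumes the standard MLUtils.loadLibSVMFile utility together with the xgboost4j-spark train signature introduced in this PR.
// Sketch only: illustrative path and values; assumes MLUtils.loadLibSVMFile and the
// XGBoost.train(trainingData, configMap, round, obj, eval) signature from this PR.
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils
import ml.dmlc.xgboost4j.scala.spark.XGBoost
object DistTrainWithSparkSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext()
    // loadLibSVMFile already yields an RDD[LabeledPoint] with sparse features,
    // which DataUtils converts per partition inside buildDistributedBoosters
    val trainRDD = MLUtils.loadLibSVMFile(sc, "hdfs:///tmp/agaricus.txt.train").repartition(4)
    val params = Map("eta" -> "1", "max_depth" -> "2", "objective" -> "binary:logistic")
    val model = XGBoost.train(trainRDD, params, round = 2)
    // train returns Option[XGBoostModel]; None means the Rabit tracker reported failure
    model.foreach(m => println(s"trained model: $m"))
  }
}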

View File

@@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
import ml.dmlc.xgboost4j.LabeledPoint
private[spark] object DataUtils extends Serializable {
object DataUtils extends Serializable {
implicit def fromSparkToXGBoostLabeledPointsAsJava(
sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
fromSparkToXGBoostLabeledPoints(sps).asJava
}
implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
java.util.Iterator[LabeledPoint] = {
(for (p <- sps) yield {
Iterator[LabeledPoint] = {
for (p <- sps) yield {
p.features match {
case denseFeature: DenseVector =>
LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
@@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable {
LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
sparseFeature.values.map(_.toFloat))
}
}).asJava
}
private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
(sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
}
private def fetchUpdateFromVector(feature: Vector) = feature match {
case denseFeature: DenseVector =>
fetchUpdateFromSparseVector(denseFeature.toSparse)
case sparseFeature: SparseVector =>
fetchUpdateFromSparseVector(sparseFeature)
}
}
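Since DataUtils is now public, downstream code can import its implicit conversions directly. A minimal sketch of what that enables, assuming the Scala DMatrix constructor that accepts an iterator of labeled points (as used in the demo above); the toy feature values are illustrative.
// Sketch only: toy data, mirrors the implicit-conversion usage in the demo.
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.DataUtils._
object DataUtilsSketch {
  def main(args: Array[String]): Unit = {
    // one dense and one sparse MLlib point; both go through the implicits above
    val points = Iterator(
      SparkLabeledPoint(1.0, Vectors.dense(0.5, 0.0, 1.2)),
      SparkLabeledPoint(0.0, Vectors.sparse(3, Array(1), Array(0.7))))
    // fromSparkToXGBoostLabeledPoints rewrites each point into an xgboost4j
    // LabeledPoint; null means no external cache file for the DMatrix
    val matrix = new DMatrix(points, null)
    println(matrix)
  }
}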

View File

@@ -16,10 +16,11 @@
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.immutable.HashMap
import scala.collection.mutable
import scala.collection.JavaConverters._
import com.typesafe.config.Config
import org.apache.spark.{TaskContext, SparkContext}
import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -28,6 +29,9 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
object XGBoost extends Serializable {
var boosters: RDD[Booster] = null
private val logger = LogFactory.getLog("XGBoostSpark")
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
new XGBoostModel(booster)
}
@@ -35,45 +39,51 @@ object XGBoost extends Serializable {
private[spark] def buildDistributedBoosters(
trainingData: RDD[LabeledPoint],
xgBoostConfMap: Map[String, AnyRef],
rabitEnv: mutable.Map[String, String],
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
import DataUtils._
val sc = trainingData.sparkContext
val tracker = new RabitTracker(numWorkers)
if (tracker.start()) {
trainingData.repartition(numWorkers).mapPartitions {
trainingSamples =>
Rabit.init(new java.util.HashMap[String, String]() {
put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
})
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
watches = new HashMap[String, DMatrix], obj, eval)
watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval)
Rabit.shutdown()
Iterator(booster)
}.cache()
} else {
null
}
}
def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
eval: EvalTrait = null): Option[XGBoostModel] = {
import DataUtils._
val numWorkers = config.getInt("numWorkers")
val round = config.getInt("round")
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int,
obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = {
val numWorkers = trainingData.partitions.length
val sc = trainingData.sparkContext
val tracker = new RabitTracker(numWorkers)
if (tracker.start()) {
// TODO: build configuration map from config
val xgBoostConfigMap = new HashMap[String, AnyRef]()
val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
obj, eval)
require(tracker.start(), "FAULT: Failed to start tracker")
boosters = buildDistributedBoosters(trainingData, configMap,
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
@volatile var booster: Booster = null
val sparkJobThread = new Thread() {
override def run() {
// force the job
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
tracker.waitFor()
// TODO: how to choose best model
Some(boosters.first())
boosters.foreachPartition(_ => ())
}
}
sparkJobThread.start()
val returnVal = tracker.waitFor()
logger.info(s"Rabit returns with exit code $returnVal")
if (returnVal == 0) {
booster = boosters.first()
Some(booster)
} else {
try {
if (sparkJobThread.isAlive) {
sparkJobThread.interrupt()
}
} catch {
case ie: InterruptedException =>
logger.info("spark job thread is interrupted")
}
None
}
}
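One behavioural note worth calling out: the number of Rabit workers is now taken from trainingData.partitions.length, so callers control parallelism by repartitioning before the call. A hedged usage sketch follows; names and values are illustrative, and it assumes XGBoostModel lives alongside XGBoost in ml.dmlc.xgboost4j.scala.spark.
// Sketch only: illustrative parameter values; assumes the package layout above.
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import ml.dmlc.xgboost4j.scala.spark.{XGBoost, XGBoostModel}
object TrainSketch {
  // numWorkers equals the number of partitions after the repartition below
  def trainOrFail(data: RDD[LabeledPoint]): XGBoostModel = {
    val params = Map("eta" -> "0.3", "max_depth" -> "3", "objective" -> "binary:logistic")
    XGBoost.train(data.repartition(8), params, round = 10) match {
      case Some(model) => model
      case None => sys.error("Rabit tracker returned a non-zero exit code")
    }
  }
}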

View File

@@ -130,6 +130,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
trainingRDD,
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap,
new scala.collection.mutable.HashMap[String, String],
numWorker, 2, null, null)
val boosterCount = boosterRDD.count()
assert(boosterCount === numWorker)

View File

@@ -1,10 +1,12 @@
package ml.dmlc.xgboost4j;
import java.io.Serializable;
/**
* Labeled data point for training examples.
* Represents a sparse training instance.
*/
public class LabeledPoint {
public class LabeledPoint implements Serializable {
/** Label of the point */
public float label;
/** Weight of this data point */

View File

@@ -24,7 +24,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Booster for xgboost, this is a model API that support interactive build of a XGBOost Model
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
*/
public class Booster implements Serializable {
private static final Log logger = LogFactory.getLog(Booster.class);
@@ -353,10 +353,26 @@ public class Booster implements Serializable {
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
* If java natively support HDFS file API, use toByteArray and write the ByteArray
*
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
return modelInfos[0];
}
/**
*
* @return the saved byte array.
* @throws XGBoostError
* @throws XGBoostError native error
*/
public byte[] toByteArray() throws XGBoostError {
byte[][] bytes = new byte[1][];

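toByteArray makes it straightforward to persist a trained booster from the JVM side in a format the other xgboost bindings can read back. A small sketch, assuming the Java Booster class shown in this diff; the output path is illustrative.
// Sketch only: illustrative path; relies on the public toByteArray above.
import java.nio.file.{Files, Paths}
import ml.dmlc.xgboost4j.java.Booster
object SaveBoosterSketch {
  // writes the raw model bytes; per the javadoc above, the resulting file
  // is compatible with the other xgboost bindings
  def save(booster: Booster, path: String): Unit =
    Files.write(Paths.get(path), booster.toByteArray)
}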
View File

@@ -1,5 +1,6 @@
package ml.dmlc.xgboost4j.java;
import java.io.Serializable;
import java.util.Iterator;
import ml.dmlc.xgboost4j.LabeledPoint;

View File

@@ -1,6 +1,7 @@
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.Serializable;
import java.util.Map;
import org.apache.commons.logging.Log;

View File

@@ -47,8 +47,15 @@ public class RabitTracker {
while ((line = reader.readLine()) != null) {
trackerProcessLogger.info(line);
}
trackerProcess.get().waitFor();
trackerProcessLogger.info("Tracker Process ends with exit code " +
trackerProcess.get().exitValue());
} catch (IOException ex) {
trackerProcessLogger.error(ex.toString());
} catch (InterruptedException ie) {
// we should not get here as RabitTracker is accessed in the main thread
ie.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
}
}
}
@@ -134,15 +141,18 @@
}
}
public void waitFor() {
public int waitFor() {
try {
trackerProcess.get().waitFor();
logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
int returnVal = trackerProcess.get().exitValue();
logger.info("Tracker Process ends with exit code " + returnVal);
stop();
return returnVal;
} catch (InterruptedException e) {
// we should not get here as RabitTracker is accessed in the main thread
e.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
return 1;
}
}
}
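Returning the exit code from waitFor lets the driver distinguish a clean run from a failed one, which is what the new XGBoost.train logic above relies on. A condensed sketch of that driver-side lifecycle; it assumes RabitTracker sits in the ml.dmlc.xgboost4j.java package, and the job-submission callback is a placeholder.
// Sketch only: simplified version of the tracker lifecycle used in XGBoost.scala.
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.RabitTracker
object TrackerSketch {
  // start the tracker, hand its environment to the workers, then wait for the result
  def runWithTracker(numWorkers: Int)(submitJob: Map[String, String] => Unit): Int = {
    val tracker = new RabitTracker(numWorkers)
    require(tracker.start(), "failed to start the Rabit tracker")
    // the worker envs (tracker host/port, etc.) must reach every task before Rabit.init
    submitJob(tracker.getWorkerEnvs.asScala.toMap)
    // waitFor now surfaces the tracker's exit code instead of returning Unit
    tracker.waitFor()
  }
}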