Merge pull request #934 from tqchen/master

[Spark] Refactor train, predict, add save
Tianqi Chen 2016-03-06 21:57:38 -08:00
commit 6f5632dd6e
5 changed files with 112 additions and 63 deletions

View File

@@ -72,3 +72,32 @@ object DistTrainWithFlink {
```
### XGBoost Spark
```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils
import ml.dmlc.xgboost4j.scala.spark.XGBoost
object DistTrainWithSpark {
def main(args: Array[String]): Unit = {
if (args.length != 3) {
println(
"usage: program num_of_rounds training_path model_path")
sys.exit(1)
}
val sc = new SparkContext()
val inputTrainPath = args(1)
val outputModelPath = args(2)
// number of iterations
val numRound = args(0).toInt
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
// training parameters
val paramMap = List(
"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
val model = XGBoost.train(trainRDD, paramMap, numRound)
// save model to HDFS path
model.saveModelToHadoop(outputModelPath)
}
}
```
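
As a follow-up to the example above, scoring with the trained model could look like the sketch below. It continues the example (reusing `sc` and `model`) and relies on the `predict(RDD[Vector])` overload added in this commit; the test path is a placeholder.

```scala
// continues the example above: `sc` and `model` as defined there
import org.apache.spark.mllib.util.MLUtils

// the RDD predict overload takes feature vectors, so drop the labels
val testRDD = MLUtils.loadLibSVMFile(sc, "hdfs:///path/to/test.libsvm").map(_.features)
// each element of the result is an Array[Array[Float]] of raw prediction scores
val predictions = model.predict(testRDD)
```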

View File

@@ -16,59 +16,30 @@
 package ml.dmlc.xgboost4j.scala.spark.example

-import java.io.File
-import scala.collection.mutable.ListBuffer
-import scala.io.Source
 import org.apache.spark.SparkContext
-import org.apache.spark.mllib.linalg.DenseVector
-import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils

-import ml.dmlc.xgboost4j.scala.DMatrix
 import ml.dmlc.xgboost4j.scala.spark.XGBoost

 object DistTrainWithSpark {
-  private def readFile(filePath: String): List[LabeledPoint] = {
-    val file = Source.fromFile(new File(filePath))
-    val sampleList = new ListBuffer[LabeledPoint]
-    for (sample <- file.getLines()) {
-      sampleList += fromSVMStringToLabeledPoint(sample)
-    }
-    sampleList.toList
-  }
-
-  private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
-    val labelAndFeatures = line.split(" ")
-    val label = labelAndFeatures(0).toInt
-    val features = labelAndFeatures.tail
-    val denseFeature = new Array[Double](129)
-    for (feature <- features) {
-      val idAndValue = feature.split(":")
-      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
-    }
-    LabeledPoint(label, new DenseVector(denseFeature))
-  }
-
   def main(args: Array[String]): Unit = {
-    import ml.dmlc.xgboost4j.scala.spark.DataUtils._
-    if (args.length != 4) {
+    if (args.length != 3) {
       println(
-        "usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
+        "usage: program num_of_rounds training_path model_path")
       sys.exit(1)
     }
     val sc = new SparkContext()
-    val inputTrainPath = args(2)
-    val inputTestPath = args(3)
-    val trainingLabeledPoints = readFile(inputTrainPath)
-    val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt)
-    val testLabeledPoints = readFile(inputTestPath).iterator
-    val testMatrix = new DMatrix(testLabeledPoints, null)
-    val booster = XGBoost.train(trainRDD,
-      List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
-        "objective" -> "binary:logistic").toMap, args(1).toInt, null, null)
-    booster.map(boosterInstance => boosterInstance.predict(testMatrix))
+    val inputTrainPath = args(1)
+    val outputModelPath = args(2)
+    // number of iterations
+    val numRound = args(0).toInt
+    val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
+    // training parameters
+    val paramMap = List(
+      "eta" -> 0.1f,
+      "max_depth" -> 2,
+      "objective" -> "binary:logistic").toMap
+    val model = XGBoost.train(trainRDD, paramMap, numRound)
+    // save model to HDFS path
+    model.saveModelToHadoop(outputModelPath)
   }
 }
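
Note that the hand-rolled LibSVM parser (fixed 129-feature dense vectors, explicit partition count) is replaced by `MLUtils.loadLibSVMFile`, which returns sparse labeled points and infers the feature count from the data. If the old behaviour of pinning the dimension and partitioning is still wanted, the standard MLlib overloads should cover it; this is a sketch, not part of the commit, and reuses `sc` and `inputTrainPath` from the example.

```scala
import org.apache.spark.mllib.util.MLUtils

// infer the feature count from the data (what the refactored example does)
val inferred = MLUtils.loadLibSVMFile(sc, inputTrainPath)
// or pin the feature dimension and the minimum number of partitions explicitly
val pinned = MLUtils.loadLibSVMFile(sc, inputTrainPath, 129, 4)
```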

View File

@@ -24,13 +24,12 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
 import ml.dmlc.xgboost4j.LabeledPoint

 object DataUtils extends Serializable {
-  implicit def fromSparkToXGBoostLabeledPointsAsJava(
-      sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
-    fromSparkToXGBoostLabeledPoints(sps).asJava
+  implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[SparkLabeledPoint])
+    : java.util.Iterator[LabeledPoint] = {
+    fromSparkPointsToXGBoostPoints(sps).asJava
   }

-  implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
+  implicit def fromSparkPointsToXGBoostPoints(sps: Iterator[SparkLabeledPoint]):
       Iterator[LabeledPoint] = {
     for (p <- sps) yield {
       p.features match {
@@ -42,4 +41,21 @@ object DataUtils extends Serializable {
       }
     }
   }
+
+  implicit def fromSparkVectorToXGBoostPointsJava(sps: Iterator[Vector])
+    : java.util.Iterator[LabeledPoint] = {
+    fromSparkVectorToXGBoostPoints(sps).asJava
+  }
+
+  implicit def fromSparkVectorToXGBoostPoints(sps: Iterator[Vector])
+    : Iterator[LabeledPoint] = {
+    for (p <- sps) yield {
+      p match {
+        case denseFeature: DenseVector =>
+          LabeledPoint.fromDenseVector(0.0f, denseFeature.values.map(_.toFloat))
+        case sparseFeature: SparseVector =>
+          LabeledPoint.fromSparseVector(0.0f, sparseFeature.indices,
+            sparseFeature.values.map(_.toFloat))
+      }
+    }
+  }
 }
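
The new Vector-based implicits default the label to 0.0f, since a bare feature vector carries no label; this is what lets prediction code build an xgboost4j iterator straight from Spark vectors. A rough illustration of using them directly is below; it assumes the iterator-plus-cache-path `DMatrix` constructor that the removed example code used (`new DMatrix(iter, null)`).

```scala
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.DataUtils._
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// two feature rows, one dense and one sparse; labels default to 0.0f
val rows: Iterator[Vector] = Iterator(
  Vectors.dense(1.0, 0.0, 2.0),
  Vectors.sparse(3, Array(1), Array(5.0)))
// with DataUtils._ in scope, the Iterator[Vector] is converted implicitly
val dmat = new DMatrix(rows, null)
```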

View File

@@ -19,6 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
 import scala.collection.mutable
 import scala.collection.JavaConverters._

+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.TaskContext
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -28,7 +32,6 @@ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker}
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}

 object XGBoost extends Serializable {
   var boosters: RDD[Booster] = null
   private val logger = LogFactory.getLog("XGBoostSpark")
@@ -38,7 +41,7 @@ object XGBoost extends Serializable {
   private[spark] def buildDistributedBoosters(
       trainingData: RDD[LabeledPoint],
-      xgBoostConfMap: Map[String, AnyRef],
+      xgBoostConfMap: Map[String, Any],
       rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
     import DataUtils._
@@ -54,13 +57,13 @@ object XGBoost extends Serializable {
     }.cache()
   }

-  def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int,
-      obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = {
+  def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
+      obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
     val numWorkers = trainingData.partitions.length
     val sc = trainingData.sparkContext
     val tracker = new RabitTracker(numWorkers)
     require(tracker.start(), "FAULT: Failed to start tracker")
-    boosters = buildDistributedBoosters(trainingData, configMap,
+    val boosters = buildDistributedBoosters(trainingData, configMap,
       tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
     @volatile var booster: Booster = null
     val sparkJobThread = new Thread() {
@@ -74,7 +77,7 @@ object XGBoost extends Serializable {
     logger.info(s"Rabit returns with exit code $returnVal")
     if (returnVal == 0) {
       booster = boosters.first()
-      Some(booster)
+      Some(booster).get
     } else {
       try {
         if (sparkJobThread.isAlive) {
@@ -84,7 +87,21 @@ object XGBoost extends Serializable {
       case ie: InterruptedException =>
         logger.info("spark job thread is interrupted")
     }
-      None
+      null
     }
   }
+
+  /**
+    * Load XGBoost model from path, using Hadoop Filesystem API.
+    *
+    * @param modelPath The path that is accessible by hadoop filesystem API.
+    * @return The loaded model
+    */
+  def loadModelFromHadoop(modelPath: String) : XGBoostModel = {
+    new XGBoostModel(
+      SXGBoost.loadModel(
+        FileSystem
+          .get(new Configuration)
+          .open(new Path(modelPath))))
+  }
 }
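
With this change `train` returns the `XGBoostModel` directly instead of an `Option` (and `null` if the Rabit tracker fails), and the parameter map is typed `Map[String, Any]` so plain numeric values can be passed. A minimal sketch of the new calling convention follows; the object name and paths are placeholders, not part of the commit.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils
import ml.dmlc.xgboost4j.scala.spark.XGBoost

object TrainAndReload {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext()
    val trainRDD = MLUtils.loadLibSVMFile(sc, "hdfs:///data/train.libsvm")
    // Map[String, Any]: numeric values no longer need to be strings
    val params = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic")
    // no Option unwrapping needed any more
    val model = XGBoost.train(trainRDD, params, 10)
    model.saveModelToHadoop("hdfs:///models/xgb.model")
    // later, reload through the Hadoop FileSystem API (HDFS or any configured scheme)
    val reloaded = XGBoost.loadModelFromHadoop("hdfs:///models/xgb.model")
  }
}
```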

View File

@@ -16,15 +16,20 @@
 package ml.dmlc.xgboost4j.scala.spark

-import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}

 class XGBoostModel(booster: Booster) extends Serializable {
-  def predict(testSet: RDD[SparkLabeledPoint]): RDD[Array[Array[Float]]] = {
+  /**
+    * Predict result given testRDD
+    * @param testSet the testSet of Data vectors
+    * @return The predicted RDD
+    */
+  def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
     import DataUtils._
     val broadcastBooster = testSet.sparkContext.broadcast(booster)
     val dataUtils = testSet.sparkContext.broadcast(DataUtils)
@@ -37,4 +42,15 @@ class XGBoostModel(booster: Booster) extends Serializable {
   def predict(testSet: DMatrix): Array[Array[Float]] = {
     booster.predict(testSet)
   }
+
+  /**
+    * Save the model as a Hadoop filesystem file.
+    *
+    * @param modelPath The model path as in Hadoop path.
+    */
+  def saveModelToHadoop(modelPath: String): Unit = {
+    booster.saveModel(FileSystem
+      .get(new Configuration)
+      .create(new Path(modelPath)))
+  }
 }
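
The RDD `predict` overload returns blocks of raw scores (`Array[Array[Float]]`, one inner array per input row). For a binary objective, callers that want one score per row can flatten the blocks roughly as below; this is a sketch, and `testVectors` (an `RDD[Vector]`) and `model` (an `XGBoostModel`) are assumed to exist.

```scala
// testVectors: RDD[Vector]; model: XGBoostModel (both assumed to exist)
val blocks = model.predict(testVectors)        // RDD[Array[Array[Float]]]
val rowScores = blocks
  .flatMap(_.iterator)                         // one Array[Float] per input row
  .map(_.head)                                 // single probability for binary:logistic
```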