Merge pull request #934 from tqchen/master
[Spark] Refactor train, predict, add save
commit 6f5632dd6e
@@ -1,6 +1,6 @@
 # XGBoost4J: Distributed XGBoost for Scala/Java
 [](https://travis-ci.org/dmlc/xgboost)
 [](https://xgboost.readthedocs.org/en/latest/jvm/index.html)
 [](../LICENSE)
 
 [Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) |
@@ -72,3 +72,32 @@ object DistTrainWithFlink {
 ```
 
 ### XGBoost Spark
+
+```scala
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.util.MLUtils
+import ml.dmlc.xgboost4j.scala.spark.XGBoost
+
+object DistTrainWithSpark {
+  def main(args: Array[String]): Unit = {
+    if (args.length != 3) {
+      println(
+        "usage: program num_of_rounds training_path model_path")
+      sys.exit(1)
+    }
+    val sc = new SparkContext()
+    val inputTrainPath = args(1)
+    val outputModelPath = args(2)
+    // number of iterations
+    val numRound = args(0).toInt
+    val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
+    // training parameters
+    val paramMap = List(
+      "eta" -> 0.1f,
+      "max_depth" -> 2,
+      "objective" -> "binary:logistic").toMap
+    val model = XGBoost.train(trainRDD, paramMap, numRound)
+    // save model to HDFS path
+    model.saveModelToHadoop(outputModelPath)
+  }
+}
+```
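Read together with the rest of this diff, the natural follow-up to the new README example is reloading the saved model and scoring a test set; `loadModelFromHadoop` and `predict(RDD[Vector])` are both introduced further down in this PR. A minimal sketch under those assumptions (the object name and argument layout are hypothetical, not part of the PR):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils
import ml.dmlc.xgboost4j.scala.spark.XGBoost

// Hypothetical companion to DistTrainWithSpark: reload the model it wrote and score a test set.
object PredictWithSavedModel {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext()
    val modelPath = args(0) // HDFS path written earlier by saveModelToHadoop
    val testPath = args(1)  // LibSVM-formatted test data
    // New in this PR: read the model back through the Hadoop FileSystem API.
    val model = XGBoost.loadModelFromHadoop(modelPath)
    // predict takes an RDD[Vector] after this PR, so strip the labels off the test set.
    val testRDD = MLUtils.loadLibSVMFile(sc, testPath).map(_.features)
    val predictions = model.predict(testRDD)
    predictions.take(1).foreach(batch => println(s"first batch: ${batch.length} rows"))
  }
}
```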
@@ -16,59 +16,30 @@
 package ml.dmlc.xgboost4j.scala.spark.example
 
-import java.io.File
-
-import scala.collection.mutable.ListBuffer
-import scala.io.Source
-
 import org.apache.spark.SparkContext
-import org.apache.spark.mllib.linalg.DenseVector
-import org.apache.spark.mllib.regression.LabeledPoint
-
-import ml.dmlc.xgboost4j.scala.DMatrix
+import org.apache.spark.mllib.util.MLUtils
 import ml.dmlc.xgboost4j.scala.spark.XGBoost
 
 object DistTrainWithSpark {
-
-  private def readFile(filePath: String): List[LabeledPoint] = {
-    val file = Source.fromFile(new File(filePath))
-    val sampleList = new ListBuffer[LabeledPoint]
-    for (sample <- file.getLines()) {
-      sampleList += fromSVMStringToLabeledPoint(sample)
-    }
-    sampleList.toList
-  }
-
-  private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
-    val labelAndFeatures = line.split(" ")
-    val label = labelAndFeatures(0).toInt
-    val features = labelAndFeatures.tail
-    val denseFeature = new Array[Double](129)
-    for (feature <- features) {
-      val idAndValue = feature.split(":")
-      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
-    }
-    LabeledPoint(label, new DenseVector(denseFeature))
-  }
-
   def main(args: Array[String]): Unit = {
-    import ml.dmlc.xgboost4j.scala.spark.DataUtils._
-    if (args.length != 4) {
+    if (args.length != 3) {
       println(
-        "usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
+        "usage: program num_of_rounds training_path model_path")
       sys.exit(1)
     }
     val sc = new SparkContext()
-    val inputTrainPath = args(2)
-    val inputTestPath = args(3)
-    val trainingLabeledPoints = readFile(inputTrainPath)
-    val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt)
-    val testLabeledPoints = readFile(inputTestPath).iterator
-    val testMatrix = new DMatrix(testLabeledPoints, null)
-    val booster = XGBoost.train(trainRDD,
-      List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
-        "objective" -> "binary:logistic").toMap, args(1).toInt, null, null)
-    booster.map(boosterInstance => boosterInstance.predict(testMatrix))
+    val inputTrainPath = args(1)
+    val outputModelPath = args(2)
+    // number of iterations
+    val numRound = args(0).toInt
+    val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
+    // training parameters
+    val paramMap = List(
+      "eta" -> 0.1f,
+      "max_depth" -> 2,
+      "objective" -> "binary:logistic").toMap
+    val model = XGBoost.train(trainRDD, paramMap, numRound)
+    // save model to HDFS path
+    model.saveModelToHadoop(outputModelPath)
   }
 }
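For context on the replacement made above: `MLUtils.loadLibSVMFile` does the parsing that the deleted `readFile`/`fromSVMStringToLabeledPoint` pair did by hand, but yields (typically sparse) mllib vectors instead of a hard-coded dense array of length 129. A minimal spark-shell style sketch, with a made-up input path:

```scala
// spark-shell sketch: `sc` is the SparkContext predefined there; the path is hypothetical.
import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(sc, "hdfs:///tmp/agaricus.txt.train")
data.take(1).foreach { lp =>
  // Each record is an mllib LabeledPoint whose features are usually a SparseVector,
  // so no fixed feature dimension has to be baked into the example any more.
  println(s"label=${lp.label}, featureType=${lp.features.getClass.getSimpleName}")
}
```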
@@ -24,13 +24,12 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
 import ml.dmlc.xgboost4j.LabeledPoint
 
 object DataUtils extends Serializable {
-  implicit def fromSparkToXGBoostLabeledPointsAsJava(
-      sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
-    fromSparkToXGBoostLabeledPoints(sps).asJava
+  implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[SparkLabeledPoint])
+    : java.util.Iterator[LabeledPoint] = {
+    fromSparkPointsToXGBoostPoints(sps).asJava
   }
 
-  implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
+  implicit def fromSparkPointsToXGBoostPoints(sps: Iterator[SparkLabeledPoint]):
     Iterator[LabeledPoint] = {
     for (p <- sps) yield {
       p.features match {
@@ -42,4 +41,21 @@ object DataUtils extends Serializable {
       }
     }
   }
+
+  implicit def fromSparkVectorToXGBoostPointsJava(sps: Iterator[Vector])
+    : java.util.Iterator[LabeledPoint] = {
+    fromSparkVectorToXGBoostPoints(sps).asJava
+  }
+
+  implicit def fromSparkVectorToXGBoostPoints(sps: Iterator[Vector])
+    : Iterator[LabeledPoint] = {
+    for (p <- sps) yield {
+      p match {
+        case denseFeature: DenseVector =>
+          LabeledPoint.fromDenseVector(0.0f, denseFeature.values.map(_.toFloat))
+        case sparseFeature: SparseVector =>
+          LabeledPoint.fromSparseVector(0.0f, sparseFeature.indices,
+            sparseFeature.values.map(_.toFloat))
+      }
+    }
+  }
 }
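The two added `fromSparkVector*` implicits are what let Spark-side code hand plain mllib `Vector`s to the native library; since the label is fixed at 0.0f they are only meant for prediction inputs. A minimal sketch, assuming the same two-argument `DMatrix` constructor that the pre-refactor example used via `new DMatrix(testLabeledPoints, null)`; the literal vectors are made up:

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.DataUtils._

object VectorsToDMatrix {
  def main(args: Array[String]): Unit = {
    // Stand-in for one partition of a test set.
    val vectors: Iterator[Vector] = Iterator(
      Vectors.dense(1.0, 0.0, 2.0),
      Vectors.sparse(3, Array(1), Array(5.0)))
    // One of the new Vector implicits is applied here, turning the Spark vectors into
    // xgboost4j LabeledPoints with label 0.0f; the second argument is the cache file,
    // null just as in the old example code.
    val matrix = new DMatrix(vectors, null)
    println(s"built DMatrix: $matrix")
  }
}
```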
@@ -19,6 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
 import scala.collection.mutable
 import scala.collection.JavaConverters._
 
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
+
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.TaskContext
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -28,7 +32,6 @@ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker}
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 
-
 object XGBoost extends Serializable {
 
   var boosters: RDD[Booster] = null
   private val logger = LogFactory.getLog("XGBoostSpark")
 
@@ -38,7 +41,7 @@ object XGBoost extends Serializable {
 
   private[spark] def buildDistributedBoosters(
       trainingData: RDD[LabeledPoint],
-      xgBoostConfMap: Map[String, AnyRef],
+      xgBoostConfMap: Map[String, Any],
       rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
     import DataUtils._
@@ -54,13 +57,13 @@ object XGBoost extends Serializable {
     }.cache()
   }
 
-  def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int,
-      obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = {
+  def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
+      obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
     val numWorkers = trainingData.partitions.length
     val sc = trainingData.sparkContext
     val tracker = new RabitTracker(numWorkers)
     require(tracker.start(), "FAULT: Failed to start tracker")
-    boosters = buildDistributedBoosters(trainingData, configMap,
+    val boosters = buildDistributedBoosters(trainingData, configMap,
       tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
     @volatile var booster: Booster = null
     val sparkJobThread = new Thread() {
@@ -74,7 +77,7 @@ object XGBoost extends Serializable {
     logger.info(s"Rabit returns with exit code $returnVal")
     if (returnVal == 0) {
       booster = boosters.first()
-      Some(booster)
+      Some(booster).get
     } else {
       try {
         if (sparkJobThread.isAlive) {
@@ -84,7 +87,21 @@ object XGBoost extends Serializable {
         case ie: InterruptedException =>
           logger.info("spark job thread is interrupted")
       }
-      None
+      null
     }
   }
 
+  /**
+   * Load XGBoost model from path, using Hadoop Filesystem API.
+   *
+   * @param modelPath The path that is accessible by hadoop filesystem API.
+   * @return The loaded model
+   */
+  def loadModelFromHadoop(modelPath: String) : XGBoostModel = {
+    new XGBoostModel(
+      SXGBoost.loadModel(
+        FileSystem
+          .get(new Configuration)
+          .open(new Path(modelPath))))
+  }
 }
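Three caller-visible things change or appear in this object: `configMap` is now `Map[String, Any]` (numeric parameters no longer need to be string-encoded, as in the old example's `"eta" -> "1"`), `train` returns an `XGBoostModel` directly instead of `Option[XGBoostModel]` (`null` on Rabit failure, per the hunk above), and `numWorkers` is derived from `trainingData.partitions.length`, so the RDD's partitioning decides how many Rabit workers start. A minimal sketch of a caller written against the new signature; the helper name, worker count, and parameter values are illustrative only:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import ml.dmlc.xgboost4j.scala.spark.{XGBoost, XGBoostModel}

def fit(trainRDD: RDD[LabeledPoint], workers: Int = 4): XGBoostModel = {
  // Map[String, Any] accepts plain Scala numbers; before this PR the map had to be
  // Map[String, AnyRef], which pushed callers toward string-encoded values.
  val params: Map[String, Any] = Map(
    "eta" -> 0.1f,
    "max_depth" -> 2,
    "objective" -> "binary:logistic")
  // One Rabit worker is launched per partition, so repartition to pick the parallelism.
  // train now returns the model itself (previously Option[XGBoostModel]); callers should
  // still expect null if the Rabit tracker reports a non-zero exit code.
  XGBoost.train(trainRDD.repartition(workers), params, round = 10)
}
```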
@@ -16,15 +16,20 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
 
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
 
 class XGBoostModel(booster: Booster) extends Serializable {
-  def predict(testSet: RDD[SparkLabeledPoint]): RDD[Array[Array[Float]]] = {
+  /**
+   * Predict result given testRDD
+   * @param testSet the testSet of Data vectors
+   * @return The predicted RDD
+   */
+  def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
     import DataUtils._
     val broadcastBooster = testSet.sparkContext.broadcast(booster)
     val dataUtils = testSet.sparkContext.broadcast(DataUtils)
@@ -37,4 +42,15 @@ class XGBoostModel(booster: Booster) extends Serializable {
   def predict(testSet: DMatrix): Array[Array[Float]] = {
     booster.predict(testSet)
   }
+
+  /**
+   * Save the model as a Hadoop filesystem file.
+   *
+   * @param modelPath The model path as in Hadoop path.
+   */
+  def saveModelToHadoop(modelPath: String): Unit = {
+    booster.saveModel(FileSystem
+      .get(new Configuration)
+      .create(new Path(modelPath)))
+  }
 }
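`XGBoostModel` now carries three entry points: distributed scoring over `RDD[Vector]`, the unchanged driver-side `predict(DMatrix)`, and the new `saveModelToHadoop`. A minimal sketch exercising them on an already-trained model; the paths are hypothetical, and the local-file `DMatrix` constructor is assumed from xgboost4j rather than shown in this diff:

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.XGBoostModel

def scoreAndPersist(model: XGBoostModel, testVectors: RDD[Vector]): Unit = {
  // Distributed scoring: partitions are converted through DataUtils and scored
  // against the broadcast Booster (typically one prediction batch per partition).
  val distributed: RDD[Array[Array[Float]]] = model.predict(testVectors)
  println(s"prediction batches: ${distributed.count()}")

  // Driver-side scoring on a single DMatrix, unchanged by this PR.
  val local: Array[Array[Float]] = model.predict(new DMatrix("/tmp/test.libsvm"))
  println(s"rows scored locally: ${local.length}")

  // New in this PR: persist the model through the Hadoop FileSystem API.
  model.saveModelToHadoop("hdfs:///tmp/xgb.model")
}
```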