[Spark] Refactor train, predict, add save
This commit is contained in:
parent
3402953633
commit
435a0425b9
@ -1,6 +1,6 @@
|
||||
# XGBoost4J: Distributed XGBoost for Scala/Java
|
||||
[](https://travis-ci.org/dmlc/xgboost)
|
||||
[](https://xgboost.readthedocs.org/en/latest/jvm/index.html)
|
||||
[](https://travis-ci.org/dmlc/xgboost)
|
||||
[](https://xgboost.readthedocs.org/en/latest/jvm/index.html)
|
||||
[](../LICENSE)
|
||||
|
||||
[Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) |
|
||||
@ -72,3 +72,32 @@ object DistTrainWithFlink {
|
||||
```
|
||||
|
||||
### XGBoost Spark
|
||||
```scala
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoost
|
||||
|
||||
object DistTrainWithSpark {
|
||||
def main(args: Array[String]): Unit = {
|
||||
if (args.length != 3) {
|
||||
println(
|
||||
"usage: program num_of_rounds training_path model_path")
|
||||
sys.exit(1)
|
||||
}
|
||||
val sc = new SparkContext()
|
||||
val inputTrainPath = args(1)
|
||||
val outputModelPath = args(2)
|
||||
// number of iterations
|
||||
val numRound = args(0).toInt
|
||||
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
|
||||
// training parameters
|
||||
val paramMap = List(
|
||||
"eta" -> 0.1f,
|
||||
"max_depth" -> 2,
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val model = XGBoost.train(trainRDD, paramMap, numRound)
|
||||
// save model to HDFS path
|
||||
model.saveModelToHadoop(outputModelPath)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@ -16,59 +16,30 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.example
|
||||
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.io.Source
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.linalg.DenseVector
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoost
|
||||
|
||||
|
||||
object DistTrainWithSpark {
|
||||
|
||||
private def readFile(filePath: String): List[LabeledPoint] = {
|
||||
val file = Source.fromFile(new File(filePath))
|
||||
val sampleList = new ListBuffer[LabeledPoint]
|
||||
for (sample <- file.getLines()) {
|
||||
sampleList += fromSVMStringToLabeledPoint(sample)
|
||||
}
|
||||
sampleList.toList
|
||||
}
|
||||
|
||||
private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
|
||||
val labelAndFeatures = line.split(" ")
|
||||
val label = labelAndFeatures(0).toInt
|
||||
val features = labelAndFeatures.tail
|
||||
val denseFeature = new Array[Double](129)
|
||||
for (feature <- features) {
|
||||
val idAndValue = feature.split(":")
|
||||
denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
|
||||
}
|
||||
LabeledPoint(label, new DenseVector(denseFeature))
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
import ml.dmlc.xgboost4j.scala.spark.DataUtils._
|
||||
if (args.length != 4) {
|
||||
if (args.length != 3) {
|
||||
println(
|
||||
"usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
|
||||
"usage: program num_of_rounds training_path model_path")
|
||||
sys.exit(1)
|
||||
}
|
||||
val sc = new SparkContext()
|
||||
val inputTrainPath = args(2)
|
||||
val inputTestPath = args(3)
|
||||
val trainingLabeledPoints = readFile(inputTrainPath)
|
||||
val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt)
|
||||
val testLabeledPoints = readFile(inputTestPath).iterator
|
||||
val testMatrix = new DMatrix(testLabeledPoints, null)
|
||||
val booster = XGBoost.train(trainRDD,
|
||||
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
"objective" -> "binary:logistic").toMap, args(1).toInt, null, null)
|
||||
booster.map(boosterInstance => boosterInstance.predict(testMatrix))
|
||||
val inputTrainPath = args(1)
|
||||
val outputModelPath = args(2)
|
||||
// number of iterations
|
||||
val numRound = args(0).toInt
|
||||
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
|
||||
// training parameters
|
||||
val paramMap = List(
|
||||
"eta" -> 0.1f,
|
||||
"max_depth" -> 2,
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val model = XGBoost.train(trainRDD, paramMap, numRound)
|
||||
// save model to HDFS path
|
||||
model.saveModelToHadoop(outputModelPath)
|
||||
}
|
||||
}
|
||||
|
||||
@ -24,13 +24,12 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
|
||||
object DataUtils extends Serializable {
|
||||
|
||||
implicit def fromSparkToXGBoostLabeledPointsAsJava(
|
||||
sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
|
||||
fromSparkToXGBoostLabeledPoints(sps).asJava
|
||||
implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[SparkLabeledPoint])
|
||||
: java.util.Iterator[LabeledPoint] = {
|
||||
fromSparkPointsToXGBoostPoints(sps).asJava
|
||||
}
|
||||
|
||||
implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
|
||||
implicit def fromSparkPointsToXGBoostPoints(sps: Iterator[SparkLabeledPoint]):
|
||||
Iterator[LabeledPoint] = {
|
||||
for (p <- sps) yield {
|
||||
p.features match {
|
||||
@ -42,4 +41,21 @@ object DataUtils extends Serializable {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
implicit def fromSparkVectorToXGBoostPointsJava(sps: Iterator[Vector])
|
||||
: java.util.Iterator[LabeledPoint] = {
|
||||
fromSparkVectorToXGBoostPoints(sps).asJava
|
||||
}
|
||||
implicit def fromSparkVectorToXGBoostPoints(sps: Iterator[Vector])
|
||||
: Iterator[LabeledPoint] = {
|
||||
for (p <- sps) yield {
|
||||
p match {
|
||||
case denseFeature: DenseVector =>
|
||||
LabeledPoint.fromDenseVector(0.0f, denseFeature.values.map(_.toFloat))
|
||||
case sparseFeature: SparseVector =>
|
||||
LabeledPoint.fromSparseVector(0.0f, sparseFeature.indices,
|
||||
sparseFeature.values.map(_.toFloat))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,6 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import scala.collection.mutable
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{Path, FileSystem}
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
@ -28,7 +32,6 @@ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
|
||||
object XGBoost extends Serializable {
|
||||
|
||||
var boosters: RDD[Booster] = null
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
|
||||
@ -38,7 +41,7 @@ object XGBoost extends Serializable {
|
||||
|
||||
private[spark] def buildDistributedBoosters(
|
||||
trainingData: RDD[LabeledPoint],
|
||||
xgBoostConfMap: Map[String, AnyRef],
|
||||
xgBoostConfMap: Map[String, Any],
|
||||
rabitEnv: mutable.Map[String, String],
|
||||
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
|
||||
import DataUtils._
|
||||
@ -54,13 +57,13 @@ object XGBoost extends Serializable {
|
||||
}.cache()
|
||||
}
|
||||
|
||||
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int,
|
||||
obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = {
|
||||
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
|
||||
obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
|
||||
val numWorkers = trainingData.partitions.length
|
||||
val sc = trainingData.sparkContext
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
boosters = buildDistributedBoosters(trainingData, configMap,
|
||||
val boosters = buildDistributedBoosters(trainingData, configMap,
|
||||
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
|
||||
@volatile var booster: Booster = null
|
||||
val sparkJobThread = new Thread() {
|
||||
@ -74,7 +77,7 @@ object XGBoost extends Serializable {
|
||||
logger.info(s"Rabit returns with exit code $returnVal")
|
||||
if (returnVal == 0) {
|
||||
booster = boosters.first()
|
||||
Some(booster)
|
||||
Some(booster).get
|
||||
} else {
|
||||
try {
|
||||
if (sparkJobThread.isAlive) {
|
||||
@ -84,7 +87,21 @@ object XGBoost extends Serializable {
|
||||
case ie: InterruptedException =>
|
||||
logger.info("spark job thread is interrupted")
|
||||
}
|
||||
None
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load XGBoost model from path, using Hadoop Filesystem API.
|
||||
*
|
||||
* @param modelPath The path that is accessible by hadoop filesystem API.
|
||||
* @return The loaded model
|
||||
*/
|
||||
def loadModelFromHadoop(modelPath: String) : XGBoostModel = {
|
||||
new XGBoostModel(
|
||||
SXGBoost.loadModel(
|
||||
FileSystem
|
||||
.get(new Configuration)
|
||||
.open(new Path(modelPath))))
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,15 +16,20 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{Path, FileSystem}
|
||||
import org.apache.spark.mllib.linalg.Vector
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||
|
||||
class XGBoostModel(booster: Booster) extends Serializable {
|
||||
|
||||
def predict(testSet: RDD[SparkLabeledPoint]): RDD[Array[Array[Float]]] = {
|
||||
/**
|
||||
* Predict result given testRDD
|
||||
* @param testSet the testSet of Data vectors
|
||||
* @return The predicted RDD
|
||||
*/
|
||||
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
|
||||
import DataUtils._
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(booster)
|
||||
val dataUtils = testSet.sparkContext.broadcast(DataUtils)
|
||||
@ -37,4 +42,15 @@ class XGBoostModel(booster: Booster) extends Serializable {
|
||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
||||
booster.predict(testSet)
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as a Hadoop filesystem file.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
*/
|
||||
def saveModelToHadoop(modelPath: String): Unit = {
|
||||
booster.saveModel(FileSystem
|
||||
.get(new Configuration)
|
||||
.create(new Path(modelPath)))
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user