[Spark] Refactor train, predict, add save
This commit is contained in:
@@ -24,13 +24,12 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
|
||||
object DataUtils extends Serializable {
|
||||
|
||||
implicit def fromSparkToXGBoostLabeledPointsAsJava(
|
||||
sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
|
||||
fromSparkToXGBoostLabeledPoints(sps).asJava
|
||||
implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[SparkLabeledPoint])
|
||||
: java.util.Iterator[LabeledPoint] = {
|
||||
fromSparkPointsToXGBoostPoints(sps).asJava
|
||||
}
|
||||
|
||||
implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
|
||||
implicit def fromSparkPointsToXGBoostPoints(sps: Iterator[SparkLabeledPoint]):
|
||||
Iterator[LabeledPoint] = {
|
||||
for (p <- sps) yield {
|
||||
p.features match {
|
||||
@@ -42,4 +41,21 @@ object DataUtils extends Serializable {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
implicit def fromSparkVectorToXGBoostPointsJava(sps: Iterator[Vector])
|
||||
: java.util.Iterator[LabeledPoint] = {
|
||||
fromSparkVectorToXGBoostPoints(sps).asJava
|
||||
}
|
||||
implicit def fromSparkVectorToXGBoostPoints(sps: Iterator[Vector])
|
||||
: Iterator[LabeledPoint] = {
|
||||
for (p <- sps) yield {
|
||||
p match {
|
||||
case denseFeature: DenseVector =>
|
||||
LabeledPoint.fromDenseVector(0.0f, denseFeature.values.map(_.toFloat))
|
||||
case sparseFeature: SparseVector =>
|
||||
LabeledPoint.fromSparseVector(0.0f, sparseFeature.indices,
|
||||
sparseFeature.values.map(_.toFloat))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import scala.collection.mutable
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{Path, FileSystem}
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
@@ -28,7 +32,6 @@ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
|
||||
object XGBoost extends Serializable {
|
||||
|
||||
var boosters: RDD[Booster] = null
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
|
||||
@@ -38,7 +41,7 @@ object XGBoost extends Serializable {
|
||||
|
||||
private[spark] def buildDistributedBoosters(
|
||||
trainingData: RDD[LabeledPoint],
|
||||
xgBoostConfMap: Map[String, AnyRef],
|
||||
xgBoostConfMap: Map[String, Any],
|
||||
rabitEnv: mutable.Map[String, String],
|
||||
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
|
||||
import DataUtils._
|
||||
@@ -54,13 +57,13 @@ object XGBoost extends Serializable {
|
||||
}.cache()
|
||||
}
|
||||
|
||||
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int,
|
||||
obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = {
|
||||
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
|
||||
obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
|
||||
val numWorkers = trainingData.partitions.length
|
||||
val sc = trainingData.sparkContext
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
boosters = buildDistributedBoosters(trainingData, configMap,
|
||||
val boosters = buildDistributedBoosters(trainingData, configMap,
|
||||
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
|
||||
@volatile var booster: Booster = null
|
||||
val sparkJobThread = new Thread() {
|
||||
@@ -74,7 +77,7 @@ object XGBoost extends Serializable {
|
||||
logger.info(s"Rabit returns with exit code $returnVal")
|
||||
if (returnVal == 0) {
|
||||
booster = boosters.first()
|
||||
Some(booster)
|
||||
Some(booster).get
|
||||
} else {
|
||||
try {
|
||||
if (sparkJobThread.isAlive) {
|
||||
@@ -84,7 +87,21 @@ object XGBoost extends Serializable {
|
||||
case ie: InterruptedException =>
|
||||
logger.info("spark job thread is interrupted")
|
||||
}
|
||||
None
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load XGBoost model from path, using Hadoop Filesystem API.
|
||||
*
|
||||
* @param modelPath The path that is accessible by hadoop filesystem API.
|
||||
* @return The loaded model
|
||||
*/
|
||||
def loadModelFromHadoop(modelPath: String) : XGBoostModel = {
|
||||
new XGBoostModel(
|
||||
SXGBoost.loadModel(
|
||||
FileSystem
|
||||
.get(new Configuration)
|
||||
.open(new Path(modelPath))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,15 +16,20 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{Path, FileSystem}
|
||||
import org.apache.spark.mllib.linalg.Vector
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||
|
||||
class XGBoostModel(booster: Booster) extends Serializable {
|
||||
|
||||
def predict(testSet: RDD[SparkLabeledPoint]): RDD[Array[Array[Float]]] = {
|
||||
/**
|
||||
* Predict result given testRDD
|
||||
* @param testSet the testSet of Data vectors
|
||||
* @return The predicted RDD
|
||||
*/
|
||||
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
|
||||
import DataUtils._
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(booster)
|
||||
val dataUtils = testSet.sparkContext.broadcast(DataUtils)
|
||||
@@ -37,4 +42,15 @@ class XGBoostModel(booster: Booster) extends Serializable {
|
||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
||||
booster.predict(testSet)
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as a Hadoop filesystem file.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
*/
|
||||
def saveModelToHadoop(modelPath: String): Unit = {
|
||||
booster.saveModel(FileSystem
|
||||
.get(new Configuration)
|
||||
.create(new Path(modelPath)))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user