spark with new labeledpoint

fix import order
CodingCat 2015-12-22 03:29:20 -06:00
parent 74bda4bfc5
commit fb41e4e673
3 changed files with 26 additions and 50 deletions

DataUtils.scala

@@ -16,17 +16,28 @@
 package ml.dmlc.xgboost4j.scala.spark
-import java.util.{Iterator => JIterator}
-import scala.collection.mutable.ListBuffer
 import scala.collection.JavaConverters._
-import ml.dmlc.xgboost4j.java.DataBatch
 import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
-import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
+import ml.dmlc.xgboost4j.LabeledPoint
 private[spark] object DataUtils extends Serializable {
+  implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
+      java.util.Iterator[LabeledPoint] = {
+    (for (p <- sps) yield {
+      p.features match {
+        case denseFeature: DenseVector =>
+          LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
+        case sparseFeature: SparseVector =>
+          LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
+            sparseFeature.values.map(_.toFloat))
+      }
+    }).asJava
+  }
   private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
     (sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
   }
@@ -37,38 +48,4 @@ private[spark] object DataUtils extends Serializable {
     case sparseFeature: SparseVector =>
       fetchUpdateFromSparseVector(sparseFeature)
   }
-  def fromLabeledPointsToSparseMatrix(points: Iterator[LabeledPoint]): JIterator[DataBatch] = {
-    // TODO: support weight
-    var samplePos = 0
-    // TODO: change hard value
-    val loadingBatchSize = 100
-    val rowOffset = new ListBuffer[Long]
-    val label = new ListBuffer[Float]
-    val featureIndices = new ListBuffer[Int]
-    val featureValues = new ListBuffer[Float]
-    val dataBatches = new ListBuffer[DataBatch]
-    for (point <- points) {
-      val (nonZeroIndices, nonZeroValues) = fetchUpdateFromVector(point.features)
-      rowOffset(samplePos) = rowOffset.size
-      label(samplePos) = point.label.toFloat
-      for (i <- nonZeroIndices.indices) {
-        featureIndices += nonZeroIndices(i)
-        featureValues += nonZeroValues(i)
-      }
-      samplePos += 1
-      if (samplePos % loadingBatchSize == 0) {
-        // create a data batch
-        dataBatches += new DataBatch(
-          rowOffset.toArray.clone(),
-          null, label.toArray.clone(), featureIndices.toArray.clone(),
-          featureValues.toArray.clone())
-        rowOffset.clear()
-        label.clear()
-        featureIndices.clear()
-        featureValues.clear()
-      }
-    }
-    dataBatches.iterator.asJava
-  }
 }
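
The implicit converter added to DataUtils above is what lets the Spark code in the next two files hand an Iterator[SparkLabeledPoint] directly to the Java DMatrix constructor. Below is a minimal, REPL-style sketch of the conversion in isolation; the converter, the LabeledPoint.fromDenseVector/fromSparseVector factories and the import names come from this diff, while the sample data is purely illustrative:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.spark.DataUtils._

// Two toy points, one dense and one sparse (illustrative data only).
val sparkPoints: Iterator[SparkLabeledPoint] = Iterator(
  SparkLabeledPoint(1.0, Vectors.dense(0.5, 0.0, 1.5)),
  SparkLabeledPoint(0.0, Vectors.sparse(3, Array(1), Array(2.0))))

// The expected type java.util.Iterator[LabeledPoint] triggers the implicit
// fromSparkToXGBoostLabeledPoints defined in DataUtils.
val xgbPoints: java.util.Iterator[LabeledPoint] = sparkPoints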

XGBoost.scala

@@ -17,15 +17,15 @@
 package ml.dmlc.xgboost4j.scala.spark
 import scala.collection.immutable.HashMap
-import scala.collection.JavaConverters._
 import com.typesafe.config.Config
-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
-import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 object XGBoost {
   private var _sc: Option[SparkContext] = None
@@ -36,6 +36,7 @@ object XGBoost {
   def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
       eval: EvalTrait = null): XGBoostModel = {
+    import DataUtils._
     val sc = trainingData.sparkContext
     val dataUtilsBroadcast = sc.broadcast(DataUtils)
     val filePath = config.getString("inputPath") // configuration entry name to be fixed
@@ -45,8 +46,7 @@ object XGBoost {
     val xgBoostConfigMap = new HashMap[String, AnyRef]()
     val boosters = trainingData.repartition(numWorkers).mapPartitions {
       trainingSamples =>
-        val dataBatches = dataUtilsBroadcast.value.fromLabeledPointsToSparseMatrix(trainingSamples)
-        val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
+        val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
         Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
     }.cache()
     // force the job
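
With import DataUtils._ in scope inside train (and in predict below), the per-partition iterator is converted implicitly at the point where it reaches the Java constructor, which is what makes the old fromLabeledPointsToSparseMatrix batching unnecessary. A hedged sketch of that per-partition step, assuming the iterator-plus-cache-argument JDMatrix constructor and the wrapping scala DMatrix constructor used in this diff are accessible from the calling code:

import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.DataUtils._

def partitionToDMatrix(samples: Iterator[SparkLabeledPoint]): DMatrix = {
  // The implicit rewrites `samples` into the java.util.Iterator[LabeledPoint]
  // that the Java DMatrix constructor expects; the second argument is passed
  // as null, mirroring the diff.
  new DMatrix(new JDMatrix(samples, null))
}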

XGBoostModel.scala

@@ -16,21 +16,20 @@
 package ml.dmlc.xgboost4j.scala.spark
-import scala.collection.JavaConverters._
+import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
+import org.apache.spark.rdd.RDD
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.rdd.RDD
 class XGBoostModel(booster: Booster) extends Serializable {
-  def predict(testSet: RDD[LabeledPoint]): RDD[Array[Array[Float]]] = {
+  def predict(testSet: RDD[SparkLabeledPoint]): RDD[Array[Array[Float]]] = {
+    import DataUtils._
     val broadcastBooster = testSet.sparkContext.broadcast(booster)
     val dataUtils = testSet.sparkContext.broadcast(DataUtils)
     testSet.mapPartitions { testSamples =>
-      val dataBatches = dataUtils.value.fromLabeledPointsToSparseMatrix(testSamples)
-      val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
+      val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
       Iterator(broadcastBooster.value.predict(dMatrix))
     }
   }