From fb41e4e6735b7cea9feb604d69f88216f1b4580e Mon Sep 17 00:00:00 2001
From: CodingCat
Date: Tue, 22 Dec 2015 03:29:20 -0600
Subject: [PATCH] spark with new labeledpoint

fix import order
---
 .../xgboost4j/scala/spark/DataUtils.scala     | 55 ++++++-------------
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala  | 10 ++--
 .../xgboost4j/scala/spark/XGBoostModel.scala  | 11 ++--
 3 files changed, 26 insertions(+), 50 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
index 4ad951567..12fb545c9 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
@@ -16,17 +16,28 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import java.util.{Iterator => JIterator}
-
-import scala.collection.mutable.ListBuffer
 import scala.collection.JavaConverters._
 
-import ml.dmlc.xgboost4j.java.DataBatch
 import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
-import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
+
+import ml.dmlc.xgboost4j.LabeledPoint
 
 private[spark] object DataUtils extends Serializable {
 
+  implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
+      java.util.Iterator[LabeledPoint] = {
+    (for (p <- sps) yield {
+      p.features match {
+        case denseFeature: DenseVector =>
+          LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
+        case sparseFeature: SparseVector =>
+          LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
+            sparseFeature.values.map(_.toFloat))
+      }
+    }).asJava
+  }
+
   private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
     (sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
   }
@@ -37,38 +48,4 @@ private[spark] object DataUtils extends Serializable {
     case sparseFeature: SparseVector => fetchUpdateFromSparseVector(sparseFeature)
   }
-
-  def fromLabeledPointsToSparseMatrix(points: Iterator[LabeledPoint]): JIterator[DataBatch] = {
-    // TODO: support weight
-    var samplePos = 0
-    // TODO: change hard value
-    val loadingBatchSize = 100
-    val rowOffset = new ListBuffer[Long]
-    val label = new ListBuffer[Float]
-    val featureIndices = new ListBuffer[Int]
-    val featureValues = new ListBuffer[Float]
-    val dataBatches = new ListBuffer[DataBatch]
-    for (point <- points) {
-      val (nonZeroIndices, nonZeroValues) = fetchUpdateFromVector(point.features)
-      rowOffset(samplePos) = rowOffset.size
-      label(samplePos) = point.label.toFloat
-      for (i <- nonZeroIndices.indices) {
-        featureIndices += nonZeroIndices(i)
-        featureValues += nonZeroValues(i)
-      }
-      samplePos += 1
-      if (samplePos % loadingBatchSize == 0) {
-        // create a data batch
-        dataBatches += new DataBatch(
-          rowOffset.toArray.clone(),
-          null, label.toArray.clone(), featureIndices.toArray.clone(),
-          featureValues.toArray.clone())
-        rowOffset.clear()
-        label.clear()
-        featureIndices.clear()
-        featureValues.clear()
-      }
-    }
-    dataBatches.iterator.asJava
-  }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8151e6ccc..49417e2c6 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -17,15 +17,15 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
 import scala.collection.immutable.HashMap
-import scala.collection.JavaConverters._
 
 import com.typesafe.config.Config
-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
-import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
+
 object XGBoost {
   private var _sc: Option[SparkContext] = None
@@ -36,6 +36,7 @@ object XGBoost {
 
   def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
       eval: EvalTrait = null): XGBoostModel = {
+    import DataUtils._
     val sc = trainingData.sparkContext
     val dataUtilsBroadcast = sc.broadcast(DataUtils)
     val filePath = config.getString("inputPath") // configuration entry name to be fixed
@@ -45,8 +46,7 @@ object XGBoost {
     val xgBoostConfigMap = new HashMap[String, AnyRef]()
     val boosters = trainingData.repartition(numWorkers).mapPartitions {
       trainingSamples =>
-        val dataBatches = dataUtilsBroadcast.value.fromLabeledPointsToSparseMatrix(trainingSamples)
-        val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
+        val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
         Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
     }.cache()
     // force the job
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
index 47efb053f..d09e43969 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -16,21 +16,20 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import scala.collection.JavaConverters._
+import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
+import org.apache.spark.rdd.RDD
 
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.rdd.RDD
 
 class XGBoostModel(booster: Booster) extends Serializable {
-  def predict(testSet: RDD[LabeledPoint]): RDD[Array[Array[Float]]] = {
+  def predict(testSet: RDD[SparkLabeledPoint]): RDD[Array[Array[Float]]] = {
+    import DataUtils._
     val broadcastBooster = testSet.sparkContext.broadcast(booster)
     val dataUtils = testSet.sparkContext.broadcast(DataUtils)
     testSet.mapPartitions { testSamples =>
-      val dataBatches = dataUtils.value.fromLabeledPointsToSparseMatrix(testSamples)
-      val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
+      val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
       Iterator(broadcastBooster.value.predict(dMatrix))
     }
   }
 }
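
Note for reviewers: below is a minimal stand-alone sketch (not part of the patch) of how the new implicit conversion is exercised. It assumes the code introduced above is on the classpath (ml.dmlc.xgboost4j.LabeledPoint with its fromDenseVector/fromSparseVector factories, and DataUtils.fromSparkToXGBoostLabeledPoints); the object name ConversionSketch and the sample data are made up for illustration. Because DataUtils is private[spark], the sketch must sit in the same package. Once the implicit is in scope, an Iterator[SparkLabeledPoint] is accepted wherever a java.util.Iterator[LabeledPoint] is expected, which is what the new JDMatrix(trainingSamples, null) calls in XGBoost.scala and XGBoostModel.scala rely on.

    // Hypothetical harness, same package as DataUtils so the
    // private[spark] object and its implicit view are visible.
    package ml.dmlc.xgboost4j.scala.spark

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}

    import ml.dmlc.xgboost4j.LabeledPoint

    object ConversionSketch {
      import DataUtils._  // brings fromSparkToXGBoostLabeledPoints into scope

      def main(args: Array[String]): Unit = {
        val sparkPoints: Iterator[SparkLabeledPoint] = Iterator(
          SparkLabeledPoint(1.0, Vectors.dense(0.5, 0.0, 1.5)),            // dense branch
          SparkLabeledPoint(0.0, Vectors.sparse(3, Array(1), Array(2.0)))) // sparse branch
        // The expected type java.util.Iterator[LabeledPoint] does not match
        // Iterator[SparkLabeledPoint], so the compiler inserts the implicit
        // conversion here; the mapping itself stays lazy, one point at a time.
        val xgbPoints: java.util.Iterator[LabeledPoint] = sparkPoints
        while (xgbPoints.hasNext) {
          println(xgbPoints.next())
        }
      }
    }

No SparkContext is needed for the conversion itself: the for/yield over the iterator plus .asJava keeps it lazy, so each partition's points are translated on the fly as JDMatrix consumes them, rather than being materialized into DataBatch chunks as the removed fromLabeledPointsToSparseMatrix did.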