example of DistTrainWithSpark and trigger job with foreachPartition

This commit is contained in:
CodingCat
2016-03-06 10:16:11 -05:00
parent f768edfede
commit 808e30f9fc
13 changed files with 588 additions and 867 deletions

View File

@@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
import ml.dmlc.xgboost4j.LabeledPoint
private[spark] object DataUtils extends Serializable {
object DataUtils extends Serializable {
implicit def fromSparkToXGBoostLabeledPointsAsJava(
sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
fromSparkToXGBoostLabeledPoints(sps).asJava
}
implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
java.util.Iterator[LabeledPoint] = {
(for (p <- sps) yield {
Iterator[LabeledPoint] = {
for (p <- sps) yield {
p.features match {
case denseFeature: DenseVector =>
LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
@@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable {
LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
sparseFeature.values.map(_.toFloat))
}
}).asJava
}
private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
(sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
}
private def fetchUpdateFromVector(feature: Vector) = feature match {
case denseFeature: DenseVector =>
fetchUpdateFromSparseVector(denseFeature.toSparse)
case sparseFeature: SparseVector =>
fetchUpdateFromSparseVector(sparseFeature)
}
}
}

View File

@@ -61,7 +61,8 @@ object XGBoost extends Serializable {
require(tracker.start(), "FAULT: Failed to start tracker")
boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval)
// force the job
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
boosters.foreachPartition(_ => ())
println("=====finished training=====")
val booster = boosters.first()
val returnVal = tracker.waitFor()
logger.info(s"Rabit returns with exit code $returnVal")