example of DistTrainWithSpark and trigger job with foreachPartition
This commit is contained in:
@@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
|
||||
private[spark] object DataUtils extends Serializable {
|
||||
object DataUtils extends Serializable {
|
||||
|
||||
implicit def fromSparkToXGBoostLabeledPointsAsJava(
|
||||
sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
|
||||
fromSparkToXGBoostLabeledPoints(sps).asJava
|
||||
}
|
||||
|
||||
implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
|
||||
java.util.Iterator[LabeledPoint] = {
|
||||
(for (p <- sps) yield {
|
||||
Iterator[LabeledPoint] = {
|
||||
for (p <- sps) yield {
|
||||
p.features match {
|
||||
case denseFeature: DenseVector =>
|
||||
LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
|
||||
@@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable {
|
||||
LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
|
||||
sparseFeature.values.map(_.toFloat))
|
||||
}
|
||||
}).asJava
|
||||
}
|
||||
|
||||
private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
|
||||
(sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
|
||||
}
|
||||
|
||||
private def fetchUpdateFromVector(feature: Vector) = feature match {
|
||||
case denseFeature: DenseVector =>
|
||||
fetchUpdateFromSparseVector(denseFeature.toSparse)
|
||||
case sparseFeature: SparseVector =>
|
||||
fetchUpdateFromSparseVector(sparseFeature)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,8 @@ object XGBoost extends Serializable {
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval)
|
||||
// force the job
|
||||
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
|
||||
boosters.foreachPartition(_ => ())
|
||||
println("=====finished training=====")
|
||||
val booster = boosters.first()
|
||||
val returnVal = tracker.waitFor()
|
||||
logger.info(s"Rabit returns with exit code $returnVal")
|
||||
|
||||
Reference in New Issue
Block a user