distributed in RDD

This commit is contained in:
CodingCat
2016-03-05 17:50:40 -05:00
parent fb41e4e673
commit 5c1af13f84
4 changed files with 116 additions and 24 deletions

View File

@@ -28,27 +28,35 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
object XGBoost {
private var _sc: Option[SparkContext] = None
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
new XGBoostModel(booster)
}
private[spark] def buildDistributedBoosters(
trainingData: RDD[LabeledPoint],
xgBoostConfMap: Map[String, AnyRef],
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
import DataUtils._
val sc = trainingData.sparkContext
val dataUtilsBroadcast = sc.broadcast(DataUtils)
trainingData.repartition(numWorkers).mapPartitions {
trainingSamples =>
val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
Iterator(SXGBoost.train(xgBoostConfMap, dMatrix, round,
watches = new HashMap[String, DMatrix], obj, eval))
}.cache()
}
def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
eval: EvalTrait = null): XGBoostModel = {
import DataUtils._
val sc = trainingData.sparkContext
val dataUtilsBroadcast = sc.broadcast(DataUtils)
val filePath = config.getString("inputPath") // configuration entry name to be fixed
val numWorkers = config.getInt("numWorkers")
val round = config.getInt("round")
val sc = trainingData.sparkContext
// TODO: build configuration map from config
val xgBoostConfigMap = new HashMap[String, AnyRef]()
val boosters = trainingData.repartition(numWorkers).mapPartitions {
trainingSamples =>
val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
}.cache()
val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
obj, eval)
// force the job
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
// TODO: how to choose best model